LangChain之关于RetrievalQA input_variables 的定义与使用

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩,也就是RAG(Retrieval Augmented Generation)。
这部分的内容其实在LangChain的官网已经给出了流程图。在这里插入图片描述
我这里就直接偷懒了,准备对Webui的项目进行复刻练习,那么接下来就是照着葫芦画瓢就行。
那么我卡在了Retrieve这一步。先放有疑惑地方的代码:

if web_content:prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。已知网络检索内容:{web_content}""" + """已知内容:{context}问题:{question}"""else:prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。已知内容:{context}问题:{question}"""prompt = PromptTemplate(template=prompt_template,input_variables=["context", "question"])......knowledge_chain = RetrievalQA.from_llm(llm=self.llm,retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),prompt=prompt)knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(input_variables=["page_content"], template="{page_content}")knowledge_chain.return_source_documents = Trueresult = knowledge_chain({"query": query})return result

我对prompt_templateknowledge_chain.combine_documents_chain.document_prompt result = knowledge_chain({"query": query})这三个地方的input_key不明白为啥一定要这样设置。虽然我也看了LangChain的API文档。但是我并未得到详细的答案,那么只能一行行看源码是到底怎么设置的了。

注意:由于LangChain是一层层封装的,那么result = knowledge_chain({"query": query})可以认为是最外层,那么我们先看最外层。

result = knowledge_chain({“query”: query})

其实这部分是直接与用户的输入问题做对接的,我们只需要定位到RetrievalQA这个类就可以了,下面是RetrievalQA这个类的实现:

class RetrievalQA(BaseRetrievalQA):"""Chain for question-answering against an index.Example:.. code-block:: pythonfrom langchain.llms import OpenAIfrom langchain.chains import RetrievalQAfrom langchain.vectorstores import FAISSfrom langchain.schema.vectorstore import VectorStoreRetrieverretriever = VectorStoreRetriever(vectorstore=FAISS(...))retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)"""retriever: BaseRetriever = Field(exclude=True)def _get_docs(self,question: str,*,run_manager: CallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())async def _aget_docs(self,question: str,*,run_manager: AsyncCallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())@propertydef _chain_type(self) -> str:"""Return the chain type."""return "retrieval_qa"

可以看到其继承了BaseRetrievalQA这个父类,同时对_get_docs这个抽象方法进行了实现。

这里要扩展的说一下,_get_docs这个方法就是利用向量相似性,在vector Base中选择与embedding之后的query最近似的Document结果。然后作为RetrievalQA的上下文。具体只需要看BaseRetrievalQA这个方法的_call和就可以了。
接下来我们只需要看BaseRetrievalQA这个类的属性就可以了。

class BaseRetrievalQA(Chain):"""Base class for question-answering chains."""combine_documents_chain: BaseCombineDocumentsChain"""Chain to use to combine the documents."""input_key: str = "query"  #: :meta private:output_key: str = "result"  #: :meta private:return_source_documents: bool = False"""Return the source documents or not."""……def _call(self,inputs: Dict[str, Any],run_manager: Optional[CallbackManagerForChainRun] = None,) -> Dict[str, Any]:"""Run get_relevant_text and llm on input query.If chain has 'return_source_documents' as 'True', returnsthe retrieved documents as well under the key 'source_documents'.Example:.. code-block:: pythonres = indexqa({'query': 'This is my query'})answer, docs = res['result'], res['source_documents']"""_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()question = inputs[self.input_key]accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters)if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)else:docs = self._get_docs(question)  # type: ignore[call-arg]answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child())if self.return_source_documents:return {self.output_key: answer, "source_documents": docs}else:return {self.output_key: answer}

可以看到其有input_key这个属性,默认值是"query"。到这里我们就可以看到result = knowledge_chain({"query": query})是调用的BaseRetrievalQA_call,这里的question = inputs[self.input_key]就是其体现。

knowledge_chain.combine_documents_chain.document_prompt

这个地方一开始我很奇怪,为什么会重新定义呢?
我们可以先定位到,combine_documents_chain这个参数的位置,其是StuffDocumentsChain的方法。

@classmethod
def from_llm(cls,llm: BaseLanguageModel,prompt: Optional[PromptTemplate] = None,callbacks: Callbacks = None,**kwargs: Any,
) -> BaseRetrievalQA:"""Initialize from LLM."""_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)document_prompt = PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")combine_documents_chain = StuffDocumentsChain(llm_chain=llm_chain,document_variable_name="context",document_prompt=document_prompt,callbacks=callbacks,)return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

可以看到原始的document_prompt中PromptTemplate的template是“Context:\n{page_content}”。因为这个项目是针对中文的,所以需要将英文的Context去掉。

扩展

  1. 这里PromptTemplate(input_variables=[“page_content”], template=“Context:\n{page_content}”)的input_variablestemplate为什么要这样定义呢?其实是根据Document这个数据对象来定义使用的,我们可以看到其数据格式为:Document(page_content=‘……’, metadata={‘source’: ‘……’, ‘row’: ……})
    那么input_variables的输入就是Document的page_content。
  2. StuffDocumentsChain中有一个参数是document_variable_name。那么这个类是这样定义的This chain takes a list of documents and first combines them into a single string. It does this by formatting each document into a string with the document_prompt and then joining them together with document_separator. It then adds that new string to the inputs with the variable name set by document_variable_name. Those inputs are then passed to the llm_chain. 这个document_variable_name简单来说就是在document_prompt中的占位符,用于在Chain中的使用。
    因此我们上文prompt_template变量中的“已知内容: {context}”,用的就是context这个变量。因此在prompt_template中换成其他的占位符都不能正常使用这个Chain。

prompt_template

在上面的拓展中其实已经对prompt_template做了部分的讲解,那么这个字符串还剩下“问题:{question}”这个地方没有说通
还是回归源码:

return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

我们可以在from_llm函数中看到其返回值是到了_call,那么剩下的我们来看这个函数:


......
uestion = inputs[self.input_key]
accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)
else:docs = self._get_docs(question)  # type: ignore[call-arg]
answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
......

这里是在run这个函数中传入了一个字典值,这个字典值有三个参数。

注意:

  1. 这三个参数就是kwargs,也就是_validate_inputs的参数input;
  2. 此时已经是在Chain这个基本类了)
def run(self,*args: Any,callbacks: Callbacks = None,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,**kwargs: Any,) -> Any:"""Convenience method for executing chain.The main difference between this method and `Chain.__call__` is that thismethod expects inputs to be passed directly in as positional arguments orkeyword arguments, whereas `Chain.__call__` expects a single input dictionarywith all the inputs"""

接下来调用__call__:

def __call__(self,inputs: Union[Dict[str, Any], Any],return_only_outputs: bool = False,callbacks: Callbacks = None,*,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,run_name: Optional[str] = None,include_run_info: bool = False,) -> Dict[str, Any]:"""Execute the chain.Args:inputs: Dictionary of inputs, or single input if chain expectsonly one param. Should contain all inputs specified in`Chain.input_keys` except for inputs that will be set by the chain'smemory.return_only_outputs: Whether to return only outputs in theresponse. If True, only new keys generated by this chain will bereturned. If False, both input keys and new keys generated by thischain will be returned. Defaults to False.callbacks: Callbacks to use for this chain run. These will be called inaddition to callbacks passed to the chain during construction, but onlythese runtime callbacks will propagate to calls to other objects.tags: List of string tags to pass to all callbacks. These will be passed inaddition to tags passed to the chain during construction, but onlythese runtime tags will propagate to calls to other objects.metadata: Optional metadata associated with the chain. Defaults to Noneinclude_run_info: Whether to include run info in the response. Defaultsto False.Returns:A dict of named outputs. Should contain all outputs specified in`Chain.output_keys`."""inputs = self.prep_inputs(inputs)......

这里的prep_inputs会调用_validate_inputs函数

def _validate_inputs(self,inputs: Dict[str, Any]) -> None:"""Check that all inputs are present."""missing_keys = set(self.input_keys).difference(inputs)if missing_keys:raise ValueError(f"Missing some input keys: {missing_keys}")

这里的input_keys通过调试,看到的就是有多个输入,分别是"input_documents"和"question"
这里的"input_documents"是来自于BaseCombineDocumentsChain

class BaseCombineDocumentsChain(Chain, ABC):"""Base interface for chains combining documents.Subclasses of this chain deal with combining documents in a variety ofways. This base class exists to add some uniformity in the interface these typesof chains should expose. Namely, they expect an input key related to the documentsto use (default `input_documents`), and then also expose a method to calculatethe length of a prompt from documents (useful for outside callers to use todetermine whether it's safe to pass a list of documents into this chain or whetherthat will longer than the context length)."""input_key: str = "input_documents"  #: :meta private:output_key: str = "output_text"  #: :meta private:

那为什么有两个呢,“question”来自于哪里?
StuffDocumentsChain继承BaseCombineDocumentsChain,其input_key是这样定义的:

  @propertydef input_keys(self) -> List[str]:extra_keys = [k for k in self.llm_chain.input_keys if k != self.document_variable_name]return super().input_keys + extra_keys

原来是重写了input_keys函数,其是对llm_chain的input_keys进行遍历。

那么llm_chain的input_keys是用其prompt的input_variables。(这里的input_variables是PromptTemplate中的[“context”, “question”])

	@propertydef input_keys(self) -> List[str]:"""Will be whatever keys the prompt expects.:meta private:"""return self.prompt.input_variables

至此,我们StuffDocumentsChain的input_keys有两个变量了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/185697.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【QEMU-tap-windows-Xshell】QEMU 创建 aarch64虚拟机(附有QEMU免费资源)

“从零开始:在Windows上创建aarch64(ARM64)虚拟机” 前言 aarch64(ARM64)架构是一种现代的、基于 ARM 技术的计算架构,具有诸多优点,如低功耗、高性能和广泛应用等。为了在 Windows 平台上体验…

(欧拉)openEuler系统添加网卡文件配置流程、(欧拉)openEuler系统手动配置ipv6地址流程、(欧拉)openEuler系统网络管理说明

文章目录 系统说明openEuler23.03系统手动配置ip流程修改名称生成网卡配置文件【openEuler23.03系统添加网卡文件配置流程】手动指定ip添加ipv6地址修改配置文件信息和名称删除创建的网卡信息重启网卡生效并测试 openEuler23.03系统网络管理说明 系统说明 我这用云上最小化安装…

【Python大数据笔记_day05_Hive基础操作】

一.SQL,Hive和MapReduce的关系 用户在hive上编写sql语句,hive把sql语句转化为MapReduce程序去执行 二.Hive架构映射流程 用户接口: 包括CLI、JDBC/ODBC、WebGUI,CLI(command line interface)为shell命令行;Hive中的Thrift服务器允许外部客户端…

Python+Selenium+Unittest 之selenium12--WebDriver操作方法2-鼠标操作1(ActionChains类简介)

在我们平时的使用过程中,会使用鼠标去进行很多操作,比如鼠标左键点击、双击、鼠标右键点击,鼠标指针悬浮、拖拽等操作。在selenium中,我们也可以去实现常用的这些鼠标操作,这时候就需要用到selenium中的ActionChains类…

SQL第三次上机作业

1.查询与王利就读同一专业学生的借书证号和姓名 USE TSGL GO SELECT Lno,Rname FROM Reader WHERE Dept(SELECT DeptFROM ReaderWHERE Rname王利) and Rname ! 王利2.查询比希望出版社出版的所有图书价格都高的图书信息 SELECT * FROM Book WHERE Price>(SELECT MAX(Price…

5G-DFS最新动态-产品不在需要走FCC官方测试

添加图片注释,不超过 140 字(可选) 最近,FCC公布了最新版本的PAG(Product Acceptance Group)清单,即388624 D02 Pre-Approval Guidance List v18r04。这个清单的主要改变是将带有雷达侦测功能的…

AVL树详解

目录 AVL树的概念 旋转的介绍 单旋转 双旋转 旋转演示 具体实现 通过高度判断的实现 通过平衡因子判断的实现 AVL树的概念 AVL树是一种自平衡的平衡二叉查找树,它是一种高效的数据结构,可以在插入和删除节点时保持树的平衡,从而保证…

【容器化】Docker

文章目录 概述环境配置的难题虚拟机Linux 容器Docker 核心概念安装命令启动与停止命令镜像相关命令容器相关命令 部署MySQL 部署Tomcat 部署Nginx 部署Redis 部署 迁移与备份Dockerfile 制作镜像Docker 私有仓库将镜像上传到私有仓库从私有仓库拉取镜像 来源 概述 环境配置的难…

pyspark将数据多次插入表的时候报错

代码 报错信息 py4j.protocol.Py4JJavaError: An error occurred while calling o129.sql. : org.apache.spark.sql.catalyst.parser.ParseException: mismatched input INSERT expecting <EOF>(line 12, pos 0) 原因 插入语句结束后没有加&#xff1b;结尾 把两个&am…

原子化 CSS 真能减少体积么?

前言 最近看到这样一篇文章&#xff1a;《要喷也得先做做功课吧&#xff1f;驳Tailwind不好论》 个人觉得说的还是有一定道理的&#xff0c;就是该作者的语气态度可能稍微冲了点&#xff1a; 不过他说的确实有道理&#xff0c;如果这种原子化工具真的如评论区里那帮人说的那么…

asp.net core mvc之路由

一、默认路由 &#xff08;Startup.cs文件&#xff09; routes.MapRoute(name: "default",template: "{controllerHome}/{actionIndex}/{id?}" ); 默认访问可以匹配到 https://localhost:44302/home/index/1 https://localhost:44302/home/index https:…

idea使用gradle教程 (idea gradle springboot)2024

这里白眉大叔&#xff0c;写一下我工作时候idea怎么使用gradle的实战步骤吧 ----windows 环境----------- 1-本机安装gradle 环境 &#xff08;1&#xff09;下载gradle Gradle需要JDK的支持&#xff0c;安装Gradle之前需要提前安装JDK8及以上版本 https://downloads.gra…

【遮天】叶凡首次高燃时刻,暴打姜峰逼其下跪,故事逐渐燃情

Hello,小伙伴们&#xff0c;我是小郑继续为大家深度解析国漫资讯。 深度爆料&#xff0c;《遮天》国漫30集剧情最新内容解析&#xff0c;前面剧情中&#xff0c;叶凡被姜峰如疯狗一般追杀&#xff0c;他像一只被狼群追逐的鹿&#xff0c;在山林中亡命逃窜。身后是姜峰那歇斯底…

el-date-picker精确到分钟

0 效果 1 代码 使用format、value-format属性格式化即可 :clearable“false” // 取消删除图标 注意&#xff1a; format&#xff1a;“yyyy-MM-dd HH:mm” 小时默认是从00:00开始 format&#xff1a;“yyyy-MM-dd hh:mm” 小时默认是从12:00开始

torch.cumprod实现累乘计算

cumprod取自“cumulative product”的缩写&#xff0c;即“累计乘法”。 数学公式为&#xff1a; y i x 1 x 2 x 3 . . . x i y_ix_1\times{x_2}\times{x_3}\times{...}\times{x_i} yi​x1​x2​x3​...xi​ 官方链接&#xff1a;torch.cumprod 用法&#xff1a; impo…

代码随想录训练营Day1:二分查找与移除元素

本专栏内容为&#xff1a;代码随想录训练营学习专栏&#xff0c;用于记录训练营的学习经验分享与总结。 文档讲解&#xff1a;代码随想录 视频讲解&#xff1a;二分查找与移除元素 &#x1f493;博主csdn个人主页&#xff1a;小小unicorn ⏩专栏分类&#xff1a;C &#x1f69a…

基于CLIP的图像分类、语义分割和目标检测

OpenAI CLIP模型是一个创造性的突破&#xff1b; 它以与文本相同的方式处理图像。 令人惊讶的是&#xff0c;如果进行大规模训练&#xff0c;效果非常好。 在线工具推荐&#xff1a; Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D…

算法进阶指南图论 道路与航线

其实再次看这题的时候。想法就是和强连通分量有关&#xff0c;我们很容易发现&#xff0c;题目中所说的双向边&#xff0c;就构成了一个强连通分量&#xff0c;而所谓的单向边&#xff0c;则相当于把强连通分量进行缩点&#xff0c;然后整个图成为了一个DAG&#xff0c;众所周知…

go程序获取工作目录及可执行程序存放目录的方法-linux

简介 工作目录 通常就是指用户启动应用程序时&#xff0c;用户当时所在的文件夹的绝对路径。 如&#xff1a;root用户登录到linux系统后&#xff0c;一顿cd&#xff08;change directory&#xff09;后, 到了/tmp文件夹下。此时&#xff0c;用户要启动某个应用程序&#xff0…