[论文阅读] Knowledge Fusion of Large Language Models

Knowledge Fusion of Large Language Models (FuseLLM)


Methodology

整体Pipeline如下图所示
在这里插入图片描述

不同的动物代表不同的LLM。左边第一,第二分别是Ensemble以及Weight Merging方法。最右侧为本文提出的FuseLLM。

  • Ensemble: 融合多个models的预测结果,比如求加权平均等。
  • Weight Merging:在权重/参数层面融合,但通常仅限于相同架构的模型。
  • FuseLLM 主要思想为:融合多个LLMs(可以是不同架构的)的probabilistic matrices,得到Fused Matrix后,喂给Target Model,起到知识蒸馏的作用

这里面会涉及到一个关键:

  • 不同LLM,使用的Tokenizer可能不同,设置也可能不一样(如 model_max_length ),分词结果可能不一样(比如对同一个句子分词,tokens总数不同),使用的Vocabulary也可能不一样,因此生成的probabilistic matrix在维度上可能有所不同,如何解决对齐问题?这个实际上就是 token alignment 问题,本文中着重描述了解决方案。

Definition of Problem

假设我们有一个语料库 C \mathcal{C} C K K K个source LLMs, 对于文本 t ∈ C t \in \mathcal{C} tC,经过 K K K个LLM处理,可以得到对应的概率分布矩阵 probabilistic distribution matrix { P t θ j } j = 1 K \{\mathbf{P}^{\theta_j}_t\}^K_{j=1} {Ptθj}j=1K,其中 θ j \theta_j θj表示第 j j j个LLM的参数。我们要做的就是将这 K K K概率分布矩阵融合,然后送入Target LLM中辅助训练:
P t = F u s i o n ( P t θ 1 , P t θ 2 , … , P t θ K ) , \begin{align} \mathbf{P}_t=\mathbb{F}\mathrm{usion}(\mathbf{P}_t^{\theta_1},\mathbf{P}_t^{\theta_2},\ldots,\mathbf{P}_t^{\theta_K}), \end{align} Pt=Fusion(Ptθ1,Ptθ2,,PtθK),
P t \mathbf{P}_t Pt即得到的融合概率分布矩阵(Fused Representation Matrix)。

为了将 P t \mathbf{P}_t Pt迁移至target model中,我们假设 Q t \mathbf{Q}_t Qt为其输出的representation matrix,则Knowledge Fusion的训练目标为:
L F u s i o n = − E t ∼ C [ D ( Q t , P t ) ] . \begin{align} \mathcal{L}_{\mathrm{Fusion}}=-\mathbb{E}_{t\sim\mathcal{C}}\left[\mathbb{D}(\mathbf{Q}_t,\mathbf{P}_t)\right]. \end{align} LFusion=EtC[D(Qt,Pt)].
其中 D ( ⋅ , ⋅ ) \mathbb{D}(\cdot, \cdot) D(,)表示差异性函数,具体实现可以是KL散度。
整体的模型损失如下:
L = λ L C L M + ( 1 − λ ) L F u s i o n . \begin{align}\mathcal{L}=\lambda\mathcal{L}_{\mathrm{CLM}}+(1-\lambda)\mathcal{L}_{\mathrm{Fusion}}.\end{align} L=λLCLM+(1λ)LFusion.
其中 L C L M \mathcal{L}_{\mathrm{CLM}} LCLM表示最原始的ground-truth之间的损失, λ \lambda λ为系数。

实现细节

Token Alignment

我们假设有两个LLM,使用不同的tokenizer。对同一段文本分词,得到的token序列不同,长度也不同:
在这里插入图片描述
如上图,用DeepSeek和TinyLlama各自的分词器分词,得到的结果完全不一样。最终预测的概率分布矩阵也不一样。

Token-Level Alignment

为了解决这个问题,FuseLLM采用基于最小编辑距离Minimal Edit Distance(MinED)的动态规划策略,在token-level实现对齐,以下图为例:
在这里插入图片描述
具体实现的源代码other.py如下:


def dtw(series_1, series_2, norm_func=np.linalg.norm):"""Use dynamic time wrapping to align to tokenizers, modified from:https://github.com/talcs/simpledtw/blob/master/simpledtw.py""""""Parameters----------series_1: List[str]blending_input_tokensseries_2: List[str]base_input_tokensnorm_func: functionedit distance evaluation between 2 tokensReturn Values----------matches: List[Tuple]matched pairs between a base token and a blending tokenmatrix[-1, -1]: int the total cost for mapping the two series of tokensmappings_series_1: List[List]mapping from blending tokens to base tokenseg: [0], [1, 2], [3, 4, 5], [6], ...mappings_series_2: List[List]mapping from base tokens to blending tokensmatrix: List[int]the dtw matrix"""matrix = np.zeros((len(series_1) + 1, len(series_2) + 1))matrix[0, :] = np.infmatrix[:, 0] = np.infmatrix[0, 0] = 0for i, vec1 in enumerate(series_1):for j, vec2 in enumerate(series_2):cost = norm_func(vec1, vec2)matrix[i + 1, j + 1] = cost + min(matrix[i, j + 1], matrix[i + 1, j], matrix[i, j])matrix = matrix[1:, 1:]i = matrix.shape[0] - 1j = matrix.shape[1] - 1matches = []mappings_series_1 = [list() for v in range(matrix.shape[0])]mappings_series_2 = [list() for v in range(matrix.shape[1])]while i > 0 or j > 0:matches.append((i, j))mappings_series_1[i].append(j)mappings_series_2[j].append(i)option_diag = matrix[i - 1, j - 1] if i > 0 and j > 0 else np.infoption_up = matrix[i - 1, j] if i > 0 else np.infoption_left = matrix[i, j - 1] if j > 0 else np.infmove = np.argmin([option_diag, option_up, option_left])if move == 0:i -= 1j -= 1elif move == 1:i -= 1else:j -= 1matches.append((0, 0))mappings_series_1[0].append(0)mappings_series_2[0].append(0)matches.reverse()for mp in mappings_series_1:mp.reverse()for mp in mappings_series_2:mp.reverse()return matches, matrix[-1, -1], mappings_series_1, mappings_series_2, matrix
Logit-Level Alignment

利用该对齐结果,将不同LLMs得到的representation matrix对齐。关键代码other.py如下:


def transform_step_logits(base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,blending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBase,base_model_vocab: Dict[str, int],base_model_input_ids: List[int],blending_model_input_ids: List[int],blending_model_per_step_logits: List[List[float]],blending_model_per_step_indices: List[List[int]],vocab_align_type: str = "hard",blending_to_base_mapping: Dict[str, str] = None,
):"""Align blending model per step logits & indices with base model.""""""Parameters----------base_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBaseblending_model_tokenizer: transformers.tokenization_utils_base.PreTrainedTokenizerBasebase_model_vocab: Dict[str, int]mapping token to id using vocabulary of base modelbase_model_input_ids: List[int]ids of base_model_input_tokensblending_model_input_ids: List[int]ids of blending_model_input_tokensblending_model_per_step_logits: List[List[float]]logits for each token in blending_model_input_tokens blending_model_per_step_indices: List[List[int]]indices corresponding to logits for each token in blending_model_input_tokens vocab_align_type: str = "hard"blending_to_base_mapping: Dict[str, str] = Nonemapping each blending token to its corresponding base token Return Values----------aligned_blending_model_per_step_logits: List[List[float]]aligned logits for each token in base_model_input_tokens for the FuseLLM trainingaligned_blending_model_per_step_indices: List[List[int]]aligned indices corresponding aligned logits for each token in base_model_input_tokens for the FuseLLM training. Use the base model vocabulary to look up the token."""base_model_tokens = base_model_tokenizer.convert_ids_to_tokens(base_model_input_ids)blending_model_tokens = blending_model_tokenizer.convert_ids_to_tokens(blending_model_input_ids)base_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[base_model_tokenizer.__class__]blending_model_special_token = TOKENIZER_TO_SPECIAL_TOKEN[blending_model_tokenizer.__class__]def dist_fn(a, b):"""Calculate editdistance between two tokens, a is from blending model, b is from base model."""aa = a.replace(blending_model_special_token, "")bb = b.replace(base_model_special_token, "")dist = editdistance.eval(aa, bb)return dist_, _, _, base_to_blending, _ = dtw(blending_model_tokens, base_model_tokens, norm_func=dist_fn)aligned_blending_model_per_step_logits, aligned_blending_model_per_step_indices = ([],[],)for i, blending_idx in enumerate(base_to_blending):aligned_blending_model_per_step_logit = []aligned_blending_model_per_step_index = []if len(blending_idx) == 1:  # one base token map to one blending tokenj = blending_idx[0]base_token = base_model_tokens[i]blending_token = blending_model_tokens[j].replace(blending_model_special_token, base_model_special_token)if ((blending_model_tokenizer.__class__== transformers.GPTNeoXTokenizerFastor blending_model_tokenizer.__class__== transformers.GPT2TokenizerFast)and i == 0and base_token.startswith(base_model_special_token)and not blending_token.startswith(base_model_special_token)):blending_token = (base_model_special_token + blending_token)  # special case for mptif vocab_align_type == "hard":if (base_token == blending_token):  # find the aligned mapping, use the corresponding logits# the logits and indices at this stepfor blending_logit, blending_index in zip(blending_model_per_step_logits[j],blending_model_per_step_indices[j],):# the token corresponds to the logit and indicesblending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(blending_model_special_token, base_model_special_token)if blending_t in base_model_vocab:aligned_index = base_model_vocab[blending_t]  # the index of the token in base model vocabif (aligned_indexnot in aligned_blending_model_per_step_index):aligned_blending_model_per_step_index.append(aligned_index)aligned_blending_model_per_step_logit.append(blending_logit)else:  # find error aligned mapping, use the one-hot logitsaligned_blending_model_per_step_index.append(base_model_vocab[base_token])aligned_blending_model_per_step_logit.append(1.0)elif vocab_align_type == "soft":if (base_token == blending_token) or (blending_token in blending_to_base_mappingand base_token == blending_to_base_mapping[blending_token]):  # find the aligned mapping, use the corresponding logits# the logits and indices at this stepfor blending_logit, blending_index in zip(blending_model_per_step_logits[j],blending_model_per_step_indices[j],):# the token corresponds to the logit and indicesblending_t = blending_model_tokenizer.convert_ids_to_tokens([blending_index])[0].replace(blending_model_special_token, base_model_special_token)blending_t = blending_to_base_mapping[blending_t]if blending_t in base_model_vocab:aligned_index = base_model_vocab[blending_t]  # the index of the token in base model vocabif (aligned_indexnot in aligned_blending_model_per_step_index):aligned_blending_model_per_step_index.append(aligned_index)aligned_blending_model_per_step_logit.append(blending_logit)else:logger.warning(f"blending_t: {blending_t} not in base_model_vocab!")else:  # find error aligned mapping, use the one-hot logitsaligned_blending_model_per_step_index.append(base_model_vocab[base_token])aligned_blending_model_per_step_logit.append(1.0)else:logger.warning(f"The vocab_align_type: '{vocab_align_type}' is not support!")raise NotImplementedErrorelse:  # one base token map to multiple blending token, in this case only fit base token. use the one-hot logitsbase_token = base_model_tokens[i]aligned_blending_model_per_step_index.append(base_model_vocab[base_token])aligned_blending_model_per_step_logit.append(1.0)aligned_blending_model_per_step_indices.append(aligned_blending_model_per_step_index)aligned_blending_model_per_step_logits.append(aligned_blending_model_per_step_logit)return (aligned_blending_model_per_step_logits,aligned_blending_model_per_step_indices,)

Fusion Strategies:

得到对其的representation matrix以后,由于不同的LLM具有不同的性能,可以使用概率分布矩阵与ground-truth之间的交叉熵损失(CE loss)评估LLM的优劣,再根据此判断选择哪些LLM参与知识融合。CE loss越低,证明模型效果更好。具体而言,作者提出了两种Fusion Strategy:

  1. MinCE: 仅选择CE loss最小的representation matrix用于知识融合。
  2. AvgCE: 基于各个模型的CE loss,采用多个representation matrices的加权平均,用于知识融合。

整体的算法流程如下:
在这里插入图片描述

  • 注:这里Eq.5实际是本文中上述的Eq.3

一些思考

本文的思路是将多个LLMs输出的概率分布矩阵视为知识,将知识融合后,送入target LLM进行训练,以达到融合多种模型知识,提升目标模型性能的目的。但在实际的实现当中我们会发现,logit-level的alignment,要么是直接采用blending_model_per_step_logits/indices,要么直接用ground-truth one-hot作为融合后的知识,而没有充分评估logit-level中,blending/base_model_per_step_logits之间的差异性。为此,Probabilistic Token Alignment for Large Language Model Fusion提出采用Probabilistic Token Alignment方法,在logit-level实现alignment。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/13989.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2024~2025学年佛山市普通高中教学质量检测(一)【高三数学】

一、选择题 本题共8小题&#xff0c;每小题5分&#xff0c;共40分。在每小题给出的四个选项中。只有一项是符合题目要求的。 1、若 5 z 2 i 1 \frac{5}{z}2i1 z5​2i1&#xff0c;则 z z z A. 1-2i B. 12i C. 2-i D. 2i2、已知集合 A { x ∣ 1 < x < a } A\left\{…

探索从传统检索增强生成(RAG)到缓存增强生成(CAG)的转变

在人工智能快速发展的当下&#xff0c;大型语言模型&#xff08;LLMs&#xff09;已成为众多应用的核心技术。检索增强生成&#xff08;RAG&#xff09;&#xff08;RAG 系统从 POC 到生产应用&#xff1a;全面解析与实践指南&#xff09;和缓存增强生成&#xff08;CAG&#x…

anaconda中可以import cv2,但是notebook中cv2 module not found

一、问题 anaconda中成功import cv2 但是jupyter notebook中却无法导入cv2 二、排查 anaconda中使用python路径如下&#xff1a; jupyter notebook中使用python路径如下&#xff1a; 可以发现路径不一致。 三、解决 ①查看可用的kernel ②选中想要修改的kernel&#xff0c;打…

【数据结构】_栈的结构与实现

目录 1. 栈的相关概念与结构 2. 栈的实现 2.1 栈实现的底层结构选择 2.2 Stack.h 2.3 Stack.c 2.4 Test_Stack.c 1. 栈的相关概念与结构 1、栈&#xff1a;一种特殊的线性表&#xff0c;只允许在固定的一端插入和删除数据&#xff1b; 允许进行数据插入和删除操作的一端…

mysql的cpu使用率100%问题排查

背景 线上mysql服务器经常性出现cpu使用率100%的告警&#xff0c; 因此整理一下排查该问题的常规流程。 1. 确认CPU占用来源 检查系统进程 使用 top 或 htop 命令&#xff0c;确认是否是 mysqld 进程导致CPU满载&#xff1a;top -c -p $(pgrep mysqld)2. 实时分析MySQL活动 …

某团面试题①—kudu读写流程

kudu 读写流程 前言 为什么会有kudu&#xff1f;先贴一个经典的图。 kudu诞生之前大数据的主要2种方式存储 静态数据 以hdfs引擎作为存储引擎&#xff0c;适用于高吞吐量的离线大数据分析场景&#xff0c;缺点是实现随机读写性能差&#xff0c;更新数据难 动态数据 以Hbase…

Deepseek本地部署指南:在linux服务器部署,在mac远程web-ui访问

1. 在Linux服务器上部署DeepSeek模型 要在 Linux 上通过 Ollama 安装和使用模型&#xff0c;您可以按照以下步骤进行操作&#xff1a; 步骤 1&#xff1a;安装 Ollama 安装 Ollama&#xff1a; 使用以下命令安装 Ollama&#xff1a; curl -sSfL https://ollama.com/download.…

go并发和并行

进程和线程 进程&#xff08;Process&#xff09;就是程序在操作系统中的一次执行过程&#xff0c;是系统进行资源分配和调度的基本单位&#xff0c;进程是一个动态概念&#xff0c;是程序在执行过程中分配和管理资源的基本单位&#xff0c;每一个进程都有一个自己的地址空间。…

element-ui rate 组件源码分享

评分组件&#xff0c;从三个方面分享&#xff1a; 1、页面结构。 2、组件属性。 3、组件方法。 一、页面结构&#xff1a; 主要有图标的、图标(默认或自定义图标)文字的、图标分数的。 二、属性。 2.1 value 2.2 max 最大分数。 2.3 disabled 是否只读 2.4 allow-half 是…

python学opencv|读取图像(五十六)使用cv2.GaussianBlur()函数实现图像像素高斯滤波处理

【1】引言 前序学习了均值滤波和中值滤波&#xff0c;对图像的滤波处理有了基础认知&#xff0c;相关文章链接为&#xff1a; python学opencv|读取图像&#xff08;五十四&#xff09;使用cv2.blur()函数实现图像像素均值处理-CSDN博客 python学opencv|读取图像&#xff08;…

HIVE如何注册UDF函数

如果注册UDF函数的时候报了上面的错误&#xff0c;说明hdfs上传的路径不正确&#xff0c; 一定要用下面的命令 hadoop fs -put /tmp/hive/111.jar /user/hive/warehouse 一定要上传到上面路径&#xff0c;这样在创建函数时&#xff0c;引用下面的地址就可以创建成功

紧跟潮流,将 DeepSeek 集成到 VSCode

Visual Studio Code&#xff08;简称 VSCode&#xff09;是一款由微软开发的免费开源代码编辑器&#xff0c;自 2015 年发布以来&#xff0c;凭借其轻便、强大、且拥有丰富扩展生态的特点&#xff0c;迅速成为了全球开发者的首选工具。VSCode 支持多平台操作系统&#xff0c;包…

HAL库 Systick定时器 基于STM32F103EZT6 野火霸道,可做参考

目录 1.时钟选择(这里选择高速外部时钟) ​编辑 2.调试模式和时基源选择: 3.LED的GPIO配置 这里用板子的红灯PB5 4.工程配置 5.1ms的systick中断实现led闪烁 源码: 6.修改systick的中断频率 7.systick定时原理 SysTick 定时器的工作原理 中断触发机制 HAL_SYSTICK_Co…

DeepSeek与llama本地部署(含WebUI)

DeepSeek从2025年1月起开始火爆&#xff0c;成为全球最炙手可热的大模型&#xff0c;各大媒体争相报道。我们可以和文心一言一样去官网进行DeepSeek的使用&#xff0c;那如果有读者希望将大模型部署在本地应该怎么做呢&#xff1f;本篇文章将会教你如何在本地傻瓜式的部署我们的…

【重新认识C语言----文件管理篇】

目录 ​编辑 -----------------------------------------begin------------------------------------- 引言 1. 文件的基本概念 2. 文件指针 3. 文件的打开与关闭 3.1 打开文件 3.2 关闭文件 4. 文件的读写操作 4.1 读取文件 4.1.1 使用fgetc()读取文件 4.1.2 使用fg…

全面解析String类

一、String 类初相识 在 C 语言的世界里&#xff0c;字符串是以\0结尾的字符集合&#xff0c;为了方便操作&#xff0c;C 标准库提供了一系列str系列的库函数&#xff0c;如strcpy、strcat、strlen等。虽然这些库函数在一定程度上满足了我们对字符串的操作需求&#xff0c;但是…

pycharm 中的 Mark Directory As 的作用是什么?

文章目录 Mark Directory As 的作用PYTHONPATH 是什么PYTHONPATH 作用注意事项 Mark Directory As 的作用 可以查看官网&#xff1a;https://www.jetbrains.com/help/pycharm/project-structure-dialog.html#-9p9rve_3 我们这里以 Mark Directory As Sources 为例进行介绍。 这…

MySQL - 字段内分组

1、MySQL 5.7及之前版本 SELECT A.要显示的字段名称,FIRST_VALUE : A.分组字段名称,last :IF(FIRST_VALUE A.分组字段名称, last 1, 1 ) AS rn,FROM 表1 A,(SELECT last : 0, FIRST_VALUE : NULL ) BORDER BY A.排序字段例&#xff1a;SELECT A.DLR_CODE,A.VAILD_CARD_NO,A.L…

瞬态分析中的时域分析与频域分析:原理、对比与应用指南

目录 一、核心概念区分 二、时域分析&#xff1a;时间维度直接求解 1. 基本原理 2. 关键特点 3. 典型算法 4. 应用案例 三、频域分析&#xff1a;频率维度的等效映射 1. 基本原理 2. 关键特点 3. 典型方法 4. 应用案例 四、对比与选择依据 1. 方法论对比 2. 工程…

【DeepSeek】DeepSeek小模型蒸馏与本地部署深度解析DeepSeek小模型蒸馏与本地部署深度解析

一、引言与背景 在人工智能领域&#xff0c;大型语言模型&#xff08;LLM&#xff09;如DeepSeek以其卓越的自然语言理解和生成能力&#xff0c;推动了众多应用场景的发展。然而&#xff0c;大型模型的高昂计算和存储成本&#xff0c;以及潜在的数据隐私风险&#xff0c;限制了…