大模型文本生成——解码策略(Top-k Top-p Temperature)

{"top_k": 10,"temperature": 0.95,"num_beams": 1,"top_p": 0.8,"repetition_penalty": 1.5,"max_tokens": 30000,"message": [{"content": "你好!","role": "user"}]
}

在大模型训练好之后,如何对训练好的模型进行解码(decode)是一个火热的研究话题。

在自然语言任务中,我们通常使用一个预训练的大模型(比如GPT)来根据给定的输入文本(比如一个开头或一个问题)生成输出文本(比如一个答案或一个结尾)。为了生成输出文本,我们需要让模型逐个预测每个 token ,直到达到一个终止条件(如一个标点符号或一个最大长度)。在每一步,模型会给出一个概率分布,表示它对下一个单词的预测。例如,如果输入的文本是“我最喜欢的”,那么模型可能会给出下面的概率分布:

那么,我们应该如何从这个概率分布中选择下一个单词呢?以下是几种常用的方法:

  • 贪心解码(Greedy Decoding):直接选择概率最高的单词。这种方法简单高效,但是可能会导致生成的文本过于单调和重复。
  • 随机采样(Random Sampling):按照概率分布随机选择一个单词。这种方法可以增加生成的多样性,但是可能会导致生成的文本不连贯和无意义。
  • Beam Search:维护一个大小为 k 的候选序列集合,每一步从每个候选序列的概率分布中选择概率最高的 k 个单词,然后保留总概率最高的 k 个候选序列。这种方法可以平衡生成的质量和多样性,但是可能会导致生成的文本过于保守和不自然。

以上方法都有各自的问题,而 top-k 采样和 top-p 采样是介于贪心解码和随机采样之间的方法,也是目前大模型解码策略中常用的方法。

top-k采样

在上面的例子中,如果使用贪心策略,那么选择的 token 必然就是“女孩”。

贪心解码是一种合理的策略,但也有一些缺点。例如,输出可能会陷入重复循环。想想智能手机自动建议中的建议。当你不断地选择建议最高的单词时,它可能会变成重复的句子。

Top-k 采样是对前面“贪心策略”的优化,它从排名前 k 的 token 中进行抽样,允许其他分数或概率较高的token 也有机会被选中。在很多情况下,这种抽样带来的随机性有助于提高生成质量。

top-k 采样的思路是,在每一步,只从概率最高的 k 个单词中进行随机采样,而不考虑其他低概率的单词。例如,如果 k=2,那么我们只从女孩、鞋子中选择一个单词,而不考虑大象、西瓜等其他单词。这样可以避免采样到一些不合适或不相关的单词,同时也可以保留一些有趣或有创意的单词。

下面是 top-k 采样的例子:

通过调整 k 的大小,即可控制采样列表的大小。“贪心策略”其实就是 k = 1的 top-k 采样。

下面是top-k 的代码实现:

import torch
from labml_nn.sampling import Sampler# Top-k Sampler
class TopKSampler(Sampler):# k is the number of tokens to pick# sampler is the sampler to use for the top-k tokens# sampler can be any sampler that takes a logits tensor as input and returns a token tensor; e.g. `TemperatureSampler`.def __init__(self, k: int, sampler: Sampler):self.k = kself.sampler = sampler# Sample from logitsdef __call__(self, logits: torch.Tensor):# New logits filled with −∞; i.e. zero probabilityzeros = logits.new_ones(logits.shape) * float('-inf')# Pick the largest k logits and their indicesvalues, indices = torch.topk(logits, self.k, dim=-1)# Set the values of the top-k selected indices to actual logits.# Logits of other tokens remain −∞zeros.scatter_(-1, indices, values)# Sample from the top-k logits with the specified sampler.return self.sampler(zeros)

总结一下,top-k 有以下有点:

  • 它可以根据不同的输入文本动态调整候选单词的数量,而不是固定为 k 个。这是因为不同的输入文本可能会导致不同的概率分布,有些分布可能比较平坦,有些分布可能比较尖锐。如果分布比较平坦,那么前 k 个单词可能都有相近的概率,那么我们就可以从中进行随机采样;如果分布比较尖锐,那么前 k 个单词可能会占据绝大部分概率,那么我们就可以近似地进行贪心解码。
  • 它可以通过调整 k 的大小来控制生成的多样性和质量。一般来说,k 越大,生成的多样性越高,但是生成的质量越低;k 越小,生成的质量越高,但是生成的多样性越低。因此,我们可以根据不同的任务和场景来选择合适的k 值。
  • 它可以与其他解码策略结合使用,例如温度调节(Temperature Scaling)、重复惩罚(Repetition Penalty)、长度惩罚(Length Penalty)等,来进一步优化生成的效果。

但是 top-k 也有一些缺点,比如:

  • 它可能会导致生成的文本不符合常识或逻辑。这是因为 top-k 采样只考虑了单词的概率,而没有考虑单词之间的语义和语法关系。例如,如果输入文本是“我喜欢吃”,那么即使饺子的概率最高,也不一定是最合适的选择,因为可能用户更喜欢吃其他食物。
  • 它可能会导致生成的文本过于简单或无聊。这是因为 top-k 采样只考虑了概率最高的 k 个单词,而没有考虑其他低概率但有意义或有创意的单词。例如,如果输入文本是“我喜欢吃”,那么即使苹果、饺子和火锅都是合理的选择,也不一定是最有趣或最惊喜的选择,因为可能用户更喜欢吃一些特别或新奇的食物。

因此,我们通常会考虑 top-k 和其它策略结合,比如 top-p。

top-p采样

top-k 有一个缺陷,那就是“k 值取多少是最优的?”非常难确定。于是出现了动态设置 token 候选列表大小策略——即核采样(Nucleus Sampling)。

top-p 采样的思路是,在每一步,只从累积概率超过某个阈值 p 的最小单词集合中进行随机采样,而不考虑其他低概率的单词。这种方法也被称为核采样(nucleus sampling),因为它只关注概率分布的核心部分,而忽略了尾部部分。例如,如果 p=0.9,那么我们只从累积概率达到 0.9 的最小单词集合中选择一个单词,而不考虑其他累积概率小于 0.9 的单词。这样可以避免采样到一些不合适或不相关的单词,同时也可以保留一些有趣或有创意的单词。

下图展示了 top-p 值为 0.9 的 Top-p 采样效果:

top-p 值通常设置为比较高的值(如0.75),目的是限制低概率 token 的长尾。我们可以同时使用 top-k 和 top-p。如果 k 和 p 同时启用,则 p 在 k 之后起作用。

下面是 top-p 代码实现的例子:

import torch
from torch import nnfrom labml_nn.sampling import Samplerclass NucleusSampler(Sampler):"""## Nucleus Sampler"""def __init__(self, p: float, sampler: Sampler):""":param p: is the sum of probabilities of tokens to pick $p$:param sampler: is the sampler to use for the selected tokens"""self.p = pself.sampler = sampler# Softmax to compute $P(x_i | x_{1:i-1})$ from the logitsself.softmax = nn.Softmax(dim=-1)def __call__(self, logits: torch.Tensor):"""Sample from logits with Nucleus Sampling"""# Get probabilities $P(x_i | x_{1:i-1})$probs = self.softmax(logits)# Sort probabilities in descending ordersorted_probs, indices = torch.sort(probs, dim=-1, descending=True)# Get the cumulative sum of probabilities in the sorted ordercum_sum_probs = torch.cumsum(sorted_probs, dim=-1)# Find the cumulative sums less than $p$.nucleus = cum_sum_probs < self.p# Prepend ones so that we add one token after the minimum number# of tokens with cumulative probability less that $p$.nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)# Get log probabilities and mask out the non-nucleussorted_log_probs = torch.log(sorted_probs)sorted_log_probs[~nucleus] = float('-inf')# Sample from the samplersampled_sorted_indexes = self.sampler(sorted_log_probs)# Get the actual indexesres = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))#return res.squeeze(-1)

Temperature采样

Temperature 采样受统计热力学的启发,高温意味着更可能遇到低能态。在概率模型中,logits 扮演着能量的角色,我们可以通过将 logits 除以温度来实现温度采样,然后将其输入 Softmax 并获得采样概率。

越低的温度使模型对其首选越有信心,而高于1的温度会降低信心。0温度相当于 argmax 似然,而无限温度相当于均匀采样。

Temperature 采样中的温度与玻尔兹曼分布有关,其公式如下所示:

\rho_i = \frac{1}{Q}e^{-\epsilon_i/kT}=\frac{e^{-\epsilon i/kT}}{\sum{j=1}^M e^{-\epsilon_j/kT}}\\

其中 

\rho_i

 是状态 

i

 的概率, 

\epsilon_i

 是状态 

i

 的能量, 

k

 是波兹曼常数, 

T

 是系统的温度, 

M

 是系统所能到达的所有量子态的数目。

有机器学习背景的朋友第一眼看到上面的公式会觉得似曾相识。没错,上面的公式跟 Softmax 函数 :

\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{c=1}^Ce^{z_c}}\\

很相似,本质上就是在 Softmax 函数上添加了温度(T)这个参数。Logits 根据我们的温度值进行缩放,然后传递到 Softmax 函数以计算新的概率分布。

上面“我喜欢漂亮的___”这个例子中,初始温度 

T=1

 ,我们直观看一下 

T

 取不同值的情况下,概率会发生什么变化:

通过上图我们可以清晰地看到,随着温度的降低,模型愈来愈越倾向选择”女孩“;另一方面,随着温度的升高,分布变得越来越均匀。当 

T=50

 时,选择”西瓜“的概率已经与选择”女孩“的概率相差无几了。

通常来说,温度与模型的“创造力”有关。但事实并非如此。温度只是调整单词的概率分布。其最终的宏观效果是,在较低的温度下,我们的模型更具确定性,而在较高的温度下,则不那么确定

下面是 Temperature 采样的代码实现:

import torch
from torch.distributions import Categoricalfrom labml_nn.sampling import Samplerclass TemperatureSampler(Sampler):"""## Sampler with Temperature"""def __init__(self, temperature: float = 1.0):""":param temperature: is the temperature to sample with"""self.temperature = temperaturedef __call__(self, logits: torch.Tensor):"""Sample from logits"""# Create a categorical distribution with temperature adjusted logitsdist = Categorical(logits=logits / self.temperature)# Samplereturn dist.sample()

联合采样(top-k & top-p & Temperature)

通常我们是将 top-k、top-p、Temperature 联合起来使用。使用的先后顺序是 top-k->top-p->Temperature。

我们还是以前面的例子为例。

首先我们设置 top-k = 3,表示保留概率最高的3个 token。这样就会保留女孩、鞋子、大象这3个 token。

  • 女孩:0.664
  • 鞋子:0.199
  • 大象:0.105

接下来,我们可以使用 top-p 的方法,保留概率的累计和达到 0.8 的单词,也就是选取女孩和鞋子这两个 token。接着我们使用 Temperature = 0.7 进行归一化,变成:

  • 女孩:0.660
  • 鞋子:0.340

接着,我们可以从上述分布中进行随机采样,选取一个单词作为最终的生成结果。

参考

https://nn.labml.ai/sampling/index.html

Huggingface的GenerationConfig 中的top_k与top_p详细解读

Temperature

ChatGPT模型采样算法详解-阿里云开发者社区

ChatGPT模型采样算法详解_JarodYv的博客-CSDN博客

大语言模型参数说明(Temperature,Top p,Top k)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/279135.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

电子招投标系统:企业在招标前,需要考虑哪些事项?

招标过程可能非常复杂和耗时&#xff0c;这使得一些企业放弃招标寻源方式。然而&#xff0c;要发展业务和客户群&#xff0c;就不能逃避招标。 在进行招标过程之前&#xff0c;首先要打好基础。让我们来看看企业在设计招标流程时应考虑哪些事项。 1. 确保有购买意向和能力 在…

Vue el-table 合并单元格

一般常见的就是下图这种的单列&#xff0c;上下重复进行合并。 有时候可能也会需要多行多列的合并。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content&qu…

kkview远程控制: 内网远程桌面控制软件

内网远程桌面控制软件&#xff1a;高效、安全的远程管理方案 在信息技术日新月异的今天&#xff0c;内网远程桌面控制软件已成为许多企业和个人用户不可或缺的工具。这类软件允许用户通过内部网络&#xff0c;实现对其他计算机的远程访问和控制&#xff0c;从而大大提高工作效…

Docker 系列2【docker安装mysql】【开启远程连接】

文章目录 前言开始步骤1. 增加mysql挂载目录2.下载镜像3. 启动mysql容器4. 配置mysql5.无法连接5.测试连接 总结 前言 本文开始&#xff0c;默认已经安装docker&#xff0c;如果你还没有完成这个步骤&#xff0c;请查看这一篇文章【docker安装与使用】 开始步骤 1. 增加mysq…

<Linux> 生产者消费者模型

目录 前言&#xff1a; 一、什么是生产者消费者模型 &#xff08;一&#xff09;概念 &#xff08;二&#xff09;生产者消费者之间的关系 &#xff08;三&#xff09;生产者消费者模型特点 &#xff08;四&#xff09;生产者消费者模型的优点 二、基于阻塞队列实现生产…

论文阅读——MoCo

Momentum Contrast for Unsupervised Visual Representation Learning 动量在数学上理解为加权移动平均&#xff1a; yt-1是上一时刻输出&#xff0c;xt是当前时刻输入&#xff0c;m是动量&#xff0c;不想让当前时刻输出只依赖于当前时刻的输入&#xff0c;m很大时&#xff0…

鸿蒙Harmony应用开发—ArkTS声明式开发(画布组件:Path2D)

路径对象&#xff0c;支持通过对象的接口进行路径的描述&#xff0c;并通过Canvas的stroke接口或者fill接口进行绘制。 说明&#xff1a; 从 API Version 8 开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。 addPath addPath(path: path2D,…

【ARM】DS中Coretex-M处理器的常用寄存器介绍

【更多软件使用问题请点击亿道电子官方网站查询】 1、 文档目标 了解ArmDS中Coretex-M处理器的常用寄存器的名称及作用。 2、 问题场景 在对Coretex-M处理器进行开发时&#xff0c;了解常用寄存器的名称及作用&#xff0c;可以&#xff1a; 编写正确的程序: 寄存器是程序员用…

cs推免相关文书模板、基本资料

目录 复试问题 文书模板 机考指南 链接:https://pan.baidu.com/s/1WAAzTPZsASNDt5XRmAO9VA?pwd=21yk 提取码:21yk --来自百度网盘超级会员V5的分享 408专业课复习 链接:https://pan.baidu.com/s/1UI0EwWTy3zn3lm3wTQJ2Dw?pwd=t5gj 提取码:t5gj --来自百度网盘超级会…

【代码】YOLOv8标注信息验证

此代码的功能是标注信息验证&#xff0c;将原图和YOLOv8标注文件&#xff08;txt&#xff09;放在同一个文件夹中&#xff0c;作为输入文件夹 程序将标注的信息还原到原图中&#xff0c;并将原图和标注后的图像一同保存&#xff0c;以便查看 两个draw_labels函数&#xff0c;分…

HarmonyOS 通知意图

之前的文章 我们讲了 harmonyos 中的 基础和进度条通知 那么 今天 我们来说说 任何给通知添加意图 通知意图 简单说 就是 当我们点击某个通知 如下图 然后 就会拉起某个 应用 就例如说 我们某个微信好友发消息给我们 我们 点击系统通知 可以直接跳到你们的聊天界面 好 回到…

哥斯拉流量webshell分析-->ASP/PHP

哥斯拉流量webshell分析 哥斯拉是继菜刀、蚁剑、冰蝎之后的又一个webshell利器&#xff0c;这里就不过多介绍了。 哥斯拉GitHub地址&#xff1a;https://github.com/BeichenDream/Godzilla 很多一线师傅不太了解其中的加解密手法&#xff0c;无法进行解密&#xff0c;这篇文章…

JupyterNotebook 如何切换使用的虚拟环境kernel

在Jupyter Notebook中&#xff0c;如果需要修改使用的虚拟环境Kernel&#xff1a; 首先&#xff0c;需要确保虚拟环境已经安装conda上【conda基本操作】 打开Jupyter Notebook。 在Jupyter Notebook的顶部菜单中&#xff0c;选择 “New” 在弹出的窗口中&#xff0c;列出了…

亮相AWE 2024,日立中央空调打造定制空气新体验

日立中央空调于3月14日携旗下空气定制全新成果&#xff0c;亮相2024中国家电及消费电子博览会&#xff08;简称AWE 2024&#xff09;现场&#xff0c;围绕“科创先行 智引未来”这一主题&#xff0c;通过技术与产品向行业与消费者&#xff0c;展现自身对于家居空气的理解。 展会…

【零基础C语言】自定义类型:结构体

1.结构体的声明 struct tag {member - list; }variable-list; 我们描述一本书或者一个学生的时候可以这样写 //书 - 书名&#xff0c;作者&#xff0c;价格&#xff0c;序号struct book {char _book_name[20];char _author[20];int price;char id[20]; };//学生 - 姓名&…

从头手搓一台ros2复合机器人(带机械臂)

一.前言 大家好呀&#xff0c;从本小节开始我们就步入了仿真篇&#xff0c;主要对机器人仿真进行介绍与操作&#xff0c;当然仿真有优点也有缺陷&#xff0c;基于对此学习&#xff0c;我们可以对上几小节创建的小车模型模拟硬件的特性&#xff0c; 比如&#xff1a; 有多重…

web学习笔记(三十三)

目录 1.严格模式 1.1严格模式的概念&#xff1a; 1.2严格模式在语义上更改的地方&#xff1a; 1.3如何开启严格模式 1.4严格模式应用上的变化 2.原型链 1.严格模式 1.1严格模式的概念&#xff1a; 严格模式有点像es5向es6过渡而产生的一种模式&#xff0c;因为es6的语法…

Python-GEE绘制DEM精美图片

目录 上传矢量和DEM获取添加颜色条参考文章 先连接上GEE的自己的项目 import ee import geemap geemap.set_proxy(port33210) ee.Authenticate() ee.Initialize(projecta-flyllf0313)上传矢量和DEM获取 使用Google Earth Engine&#xff08;GEE&#xff09;和Google Earth Eng…

linux:线程互斥

个人主页 &#xff1a; 个人主页 个人专栏 &#xff1a; 《数据结构》 《C语言》《C》《Linux》 文章目录 前言一、线程互斥问题解释互斥量的接口 二、加锁的原理三、 死锁死锁四个必要条件避免死锁 总结 前言 本文是对于线程互斥的知识总结 一、线程互斥 问题 我们先看下面…

基础乱炖来吧

1,SSH框架和SSM区别 SSH&#xff1a;structspringhibernate&#xff0c;SSM&#xff1a;MVCspringmybatis struct入口是filter级别&#xff0c;对action类进行请求&#xff0c;一个action类对应一个请求、类拦截&#xff1b;spring-mvcservlet级别&#xff0c;方法级别请求&…