文本生成模型如何解码

文章目录

    • 解码方法
      • Greedy Search
      • Beam Search
      • sampling
      • Temperature Sampling
      • top-k sampling
      • Top-p (nucleus) sampling
      • Contrastive search
    • 总结
    • 相关资源

语言模型如何对于一个给定输入生成相应的输出呢?答案是使用解码策略(decoding strategy)。这里对现有的解码策略做一个记录。

解码方法

与huggingface的how to generate 一样,用流行的transformers包和GPT2模型来对各个解码方法测试生成效果,先加载模型:

# transformers的安装命令: pip install -q transformers
# 导入对象
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch# 确定推理设备
torch_device = "cuda" if torch.cuda.is_available() else "cpu"# 加载分词器,第一次调用会先下载
tokenizer = AutoTokenizer.from_pretrained("gpt2")# 加载模型,第一次调用会先下载
# add the EOS token as PAD token to avoid warnings
model = AutoModelForCausalLM.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id).to(torch_device)

Greedy Search

Greedy Search贪心搜索就是在每个时间步,都选择概率最大的词汇作为下一个词。比如下面的图片,从词"The"开始,算法先贪心的选择概率最大的词"nice",接着选择概率最大的"women"。

在这里插入图片描述

如果使用transformers的generate函数来生成文本,不指定参数的话,默认就是使用贪心搜索。

# encode context the generation is conditioned on
model_inputs = tokenizer('I enjoy playing badminton', return_tensors='pt').to(torch_device)# generate 40 new tokens
greedy_output = model.generate(**model_inputs, max_new_tokens=40)print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, but I'm not a big fan of the idea of playing badminton. I think it's a bit too much of a distraction. I think it's a distraction that's not going to
  • 贪心搜索得到的最终序列不一定是最优的句子,因为最优的句子的前面的词的概率可能会比较低,但是句子整体的概率更高。就像上面图片中的[‘the’, ‘dog’, ‘has’] 的概率比[‘the’, ‘nice’, ‘women’]要高。
  • 从上面的示例结果中发现生成的内容有重复,这是语言模型生成存在的一个问题,在贪心搜索和beam search中会更常见。
  • 使用LLM生成结果时,有一个Temperature参数,比如openai 的 api ,当Temperature=0时就是使用的贪心搜索。

Beam Search

因为贪心搜索每次选择概率最大的词可能会错过整体概率更高的句子;为了减轻这个风险,Beam Search 通过在每个时间步保留num_beams个概率最高的词,最终选择整体概率最大的句子。

下面的图片示意了num_beams=2的情形:

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使num_beams>1并且do_sample=False(默认即为False),就是使用的beam search方法。

# activate beam search and early_stopping
beam_output = model.generate(**model_inputs,max_new_tokens=40,num_beams=5,early_stopping=True
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play badm

我们也可以尝试将beam search 生成的句子都打印出来(用参数return_num_sequences,注意要小于等于num_beams),可以发现生成的几个句子差别不太大。

# set return_num_sequences > 1
beam_outputs = model.generate(**model_inputs,max_new_tokens=40,num_beams=5,num_return_sequences=5,early_stopping=True
)# now we have 5 output sequences
print("Output:\n" + 100 * '-')
for i, beam_output in enumerate(beam_outputs):print("{}: {}".format(i, tokenizer.decode(beam_output, skip_special_tokens=True)))
Output:
----------------------------------------------------------------------------------------------------
0: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play badm
1: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton.
2: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I don't like to play goodm
3: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton."
4: I enjoy playing badminton, but I don't like to play badminton. I don't like to play badminton. I don't like to play badminton. I like to play badminton,
  • Beam Search 可以保证比贪心搜索生成概率更高的句子,但是仍然不能保证找到最有可能的句子。
  • Beam Search的重复句子生成可以用n-grams惩罚来减轻,n-gram惩罚保证每个n-gram不会出现两次,方法是如果看到当前候选词与其上文所组成的 n-gram 已经出现过了,就将该候选词的概率设置为 0 。transformers包可以使用参数no_repeat_ngram_sizeno_repeat_ngram_size=2就是任意2-gram不会出现两次。
  • 在机器翻译或摘要等任务中,因为所需生成的文本长度或多或少都是可预测的,所以beam search效果比较好 - 参见 Murray et al. (2018) 和 Yang et al. (2018)的工作。但开放域文本生成情况有所不同,其输出文本长度可能会有很大差异,如对话和故事生成的输出文本长度就有很大不同。

sampling

采样就意味着不确定性,它根据当前条件概率分布随机选择下一个词。也就是每一个单词都有一定的几率会被选择,比如上面的图片中的例子,可视化出来就如下图,单词”car"从条件概率分布P(w|"The")中被采样到,接下来"drive"从P(w|"the", "car")被采样。

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=Truetop_k=0,就是使用采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
from transformers import set_seed
set_seed(42)# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(**model_inputs,max_new_tokens=40,do_sample=True,top_k=0
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and coach would agree, it's a lot tougher than most other AC athletes. American hockey SalariesThe Miami Dolphins
  • sampling 方法的问题模型可能会生成一些不太连贯的胡言乱语

Temperature Sampling

我们知道softmax的表达式如下式
p i = e x p ( z i ) ∑ j = 1 N e x p ( z j ) p_i = \frac {exp(z_i)} {\sum^N_{j=1} exp(z_j)} pi=j=1Nexp(zj)exp(zi)
而带Temperature的softmax的表达式如下式:
p i = e x p ( z i / τ ) ∑ j = 1 N e x p ( z j / τ ) p_i = \frac {exp(z_i/\tau)} {\sum^N_{j=1} exp(z_j/\tau)} pi=j=1Nexp(zj/τ)exp(zi/τ)

Temperature=1时就是普通的softmax,加了temperature之后可以让原本的概率分布更加两级分化(Temperature<1)或更平缓(Temperature>1)。

用如下代码生成的下图可以直观感受一下Temperature的效果:
在这里插入图片描述

import math
from matplotlib import pyplot as plt
import numpy as np
import torchdef softmax(vec, temperature):"""turn vec into normalized probability"""sum_exp = sum(math.exp(x/temperature) for x in vec)return [math.exp(x/temperature)/sum_exp for x in vec]def main():vec = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]ts = [0.1, 0.3, 0.6, 1, 1.5, 10, 100, 10000]for t in ts:result = softmax(vec, t)print(t, result)plt.plot(result, label=t)plt.legend()plt.show()if __name__ == "__main__":main()
-----------输出结果-----------------------
0.1 [8.193640616392913e-40, 1.8047694477191753e-35, 3.975269250769863e-31, 8.75611321772293e-27, 1.9286622828562907e-22, 4.2481613803067925e-18, 9.357198133414645e-14, 2.0610600462088695e-09, 4.5397868608862414e-05, 0.9999546000702376]
0.3 [9.023799189303686e-14, 2.5295175399808997e-12, 7.090648684486909e-11, 1.987624041824023e-09, 5.5716331571752974e-08, 1.5618193071184212e-06, 4.37803329701724e-05, 0.0012272338715773265, 0.034401359545912634, 0.964326006652751]
0.6 [2.48124849643664e-07, 1.3136945477127512e-06, 6.9553427122218854e-06, 3.682499278746801e-05, 0.00019496955792188005, 0.0010322643845619335, 0.00546531351351773, 0.028936048020019006, 0.15320161834191354, 0.811124444027169]
1 [7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866]
1.5 [0.0012076552782540224, 0.002352191295314716, 0.0045814430569569645, 0.008923432599188675, 0.017380473436496794, 0.03385253976191134, 0.06593574407043169, 0.12842529324824872, 0.25013831539204334, 0.4872029118611537]
10 [0.06120702456008912, 0.0676442235257524, 0.07475842861647011, 0.08262084118795704, 0.09131015090787675, 0.10091332330848407, 0.11152647016690201, 0.12325581142409142, 0.136218738269722, 0.150544988032655]
100 [0.09556032473672185, 0.09652072196694327, 0.09749077134979559, 0.09847056989102544, 0.09946021557130351, 0.10045980735602247, 0.10146944520519384, 0.10248923008344388, 0.10351926397011023, 0.10455964986943994]
10000 [0.09995500600033737, 0.0999650020007291, 0.09997499900077084, 0.09998499700056258, 0.09999499600020427, 0.10000499599979594, 0.10001499699943757, 0.10002499899922916, 0.10003500199927076, 0.10004500599966237]

如果使用transformers的generate函数来生成文本,使do_sample=True时,可以设置Temperature参数(默认值为1),比如使temperature=0.6:

# set seed to reproduce results. Feel free to change the seed though to get different results
from transformers import set_seed
set_seed(42)# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(**model_inputs,max_new_tokens=40,do_sample=True,top_k=0,temperature=0.6
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton, and I was delighted to have the opportunity to play against the best players from the world."I'm looking forward to the challenge of playing against some of the best players from the country
  • 可以发现将Temperature降低(例子从1变成0.6)后,因为将分布变得更两极化(增加高概率单词的可能性,降低低概率词的可能性),所以这次的生成内容更连贯了。如果还是用前一节的可视化例子的话,示意图类似如下

    在这里插入图片描述

  • 当设置 T e m p e r a t u r e → 0 Temperature \rightarrow 0 Temperature0时temperature采样也就等同于贪心搜索,比如在LLAMA代码中temperature=0时就是用的贪心搜索

top-k sampling

论文《Hierarchical Neural Story Generation》中提出top-k sampling方法 ,它在每个时间步先选出K个最可能的下一个词,将它们的概率进行缩放调整后在这K个词中进行采样。在GPT-2的论文中生成故事的时候就是使用的top-k采样方法。

将前面的例子中的下一个词从3个扩展到10个来可视化top-k sampling,设k=6,如下图所示:

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=Truetop_k>0,就是使用top-k采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
set_seed(42)# set top_k to 50
sample_output = model.generate(**model_inputs,max_new_tokens=40,do_sample=True,top_k=50
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and I would much rather spend my time here. I play much more than most American hockey players and I appreciate the community.
  • top-k 采样的结果看起来更自然

  • top-k采样的问题是因为不能动态调整单词的个数,有时候会像上图右图一样包括一些不太适合的词。

Top-p (nucleus) sampling

top-p采样方法出自论文《The Curious Case of Neural Text Degeneration》, 它在每个时间步,选出累积概率和超过概率p的最小单词集,将它们的概率进行缩放调整后在这个单词集中进行采样。这样得到的单词集的大小会根据下一个词的概率分布动态增加或减少。

比如如果设p=0.92,与前面top-k采样中同样的例子,如下图所示进行采样的候选词集是不一样的

在这里插入图片描述

如果使用transformers的generate函数来生成文本,使do_sample=True0<top_p<1,就是使用top-p采样方式解码:

# set seed to reproduce results. Feel free to change the seed though to get different results
set_seed(42)# set top_p to 0.92
sample_output = model.generate(**model_inputs,max_new_tokens=40,do_sample=True,top_p=0.92,top_k=0
)print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy playing badminton more than any other sport. I know more about winning than any other athlete and coach would agree, it's a lot tougher than most other sports because everyone is playing badminton. So I'm

在LLAMA的生成代码中,top-p的实现如下:

def sample_top_p(probs, p):"""Perform top-p (nucleus) sampling on a probability distribution.Args:probs (torch.Tensor): Probability distribution tensor.p (float): Probability threshold for top-p sampling.Returns:torch.Tensor: Sampled token indices.Note:Top-p sampling selects the smallest set of tokens whose cumulative probability massexceeds the threshold p. The distribution is renormalized based on the selected tokens."""probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)probs_sum = torch.cumsum(probs_sort, dim=-1)mask = probs_sum - probs_sort > pprobs_sort[mask] = 0.0probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))next_token = torch.multinomial(probs_sort, num_samples=1)next_token = torch.gather(probs_idx, -1, next_token)return next_token

Contrastive search

待学习总结,可参考huggingface blog。

总结

每种解码方法各有优点,都有适应的场景,可根据实际测试情况选择最适合自己的方法。

相关资源

  1. huggingface transformers关于文本生成的文档:

    • how to generate (本文笔记中的大部分代码和图片来自此文)
    • 生成相关文档的GitHub issue讨论
    • transformers里的解码策略
    • transformers 文本生成相关的类的说明文档
  2. https://nn.labml.ai/sampling/index.html

  3. https://finisky.github.io/illustrated-decoding-strategies/

  4. https://blog.csdn.net/muyao987/article/details/125917234

  5. openai 的 api文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/127145.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

用python实现基本数据结构【02/4】

*说明 如果需要用到这些知识却没有掌握&#xff0c;则会让人感到沮丧&#xff0c;也可能导致面试被拒。无论是花几天时间“突击”&#xff0c;还是利用零碎的时间持续学习&#xff0c;在数据结构上下点功夫都是值得的。那么Python 中有哪些数据结构呢&#xff1f;列表、字典、集…

【白话机器学习系列】白话梯度下降

白话梯度下降 梯度下降是机器学习中最常见的优化算法之一。理解它的基本实现是理解所有基于它构建的高级优化算法的基础。 文章目录 优化算法一维梯度下降均方误差梯度下降什么是均方误差单权重双权重三权重三个以上权重 矩阵求导结论 优化算法 在机器学习中&#xff0c;优化是…

ChatGPT实战与私有化大模型落地

文章目录 大模型现状baseline底座选择数据构造迁移方法评价思考 领域大模型训练技巧Tokenizer分布式深度学习数据并行管道并行向量并行分布式框架——Megatron-LM分布式深度学习框架——Colossal-AI分布式深度学习框架——DeepSpeedP-tuning 微调 资源消耗模型推理加速模型推理…

Python批处理(一)提取txt中数据存入excel

Python批处理&#xff08;一&#xff09;提取txt中数据存入excel 问题描述 现从冠层分析软件中保存了叶面积指数分析的结果&#xff0c;然而软件保存格式为txt&#xff0c;且在不同的文件夹中&#xff0c;每个文件夹的txt文件数量不固定&#xff0c;但是txt文件格式固定。现需…

C#__多线程之任务和连续任务

/// <summary> /// /// 任务&#xff1a;System.Threading.Tasks&#xff08;异步编程的一种实现方式&#xff09; /// 表应完成某个单元工作。这个工作可以在单独的线程中运行&#xff0c;也可以以同步方式启动一个任务。 /// /// 连续任务&#…

thinkphp6-简简单单地开发接口

目录 1.前言TP6简介 2.项目目录3.运行项目运行命令访问规则 4.model db使用db连接配置model编写及调用调用接口 5.返回json格式 1.前言 基于上篇文章环境搭建后&#xff0c;便开始简单学习上手开发接口…记录重要的过程&#xff01; Windows-试用phpthink发现原来可这样快速搭…

如何使用SQL SERVER的OpenQuery

如何使用SQL SERVER的OpenQuery 一、OpenQuery使用说明二、 OpenQuery语法2.1 参数说明2.2注解 三、示例3.1 执行 SELECT 传递查询3.2 执行 UPDATE 传递查询3.3 执行 INSERT传递查询3.4 执行 DELETE 传递查询 一、OpenQuery使用说明 在指定的链接服务器上执行指定的传递查询。 …

电工什么是电动势

什么是电动势&#xff1f;及电源电动势计算公式与方向确定 前面我们讲到在基本电路中的电流和电压的基础知识&#xff0c;而本文要讲的电动势和电压是一个很类似的概念。那么什么是电动势&#xff1f;电源电动势的计算公式是什么&#xff1f;它的方向如何确定及与电压有什么区…

轻量容器引擎Docker基础使用

轻量容器引擎Docker Docker是什么 Docker 是一个开源项目&#xff0c;诞生于 2013 年初&#xff0c;最初是 dotCloud 公司内部的一个业余项目。 它基于 Google 公司推出的 Go 语言实现&#xff0c;项目后来加入了 Linux 基金会&#xff0c;遵从了 Apache 2.0 协议&#xff0c;…

【Redis】深入探索 Redis 的数据类型 —— 哈希表 hash

文章目录 前言一、hash 类型相关命令1.1 HSET 和 HSETNX1.2 HGET 和 HMGET1.3 HKEYS、HVALS 和 HGETALL1.4 HEXISTS 和 HDEL1.5 HLEN1.6 HINCRBY 和 HINCRBYFLOAT1.7 哈希相关命令总结 二、hash 类型内部编码三、hash 类型的应用场景四、原生&#xff0c;序列化&#xff0c;哈希…

Android相机-架构3

目录 引言 1. Android相机的整体架构 2. 相机 HAL 2.1 AIDL相机HAL 2.2 相机 HAL3 功能 3. HAL子系统 3.1 请求 3.2 HAL和相机子系统 3.2.1 相机的管道 3.2.2 使用 Android Camera API 的步骤 3.2.3 HAL 操作摘要 3.3 启动和预期操作顺序 3.3.1 枚举、打开相机设备…

C语言课程作业

本科期间c语言课程作业代码整理&#xff1a; Josephus链表实现 Josephus 层序遍历树 二叉树的恢复 哈夫曼树 链表的合并 中缀表达式 链接&#xff1a;https://pan.baidu.com/s/1Q7d-LONauNLi7nJS_h0jtw?pwdswit 提取码&#xff1a;swit

《TCP/IP网络编程》阅读笔记--进程间通信

目录 1--进程间通信 2--pipe()函数 3--代码实例 3-1--pipe1.c 3-2--pipe2.c 3-3--pipe3.c 3-4--保存信息的回声服务器端 1--进程间通信 为了实现进程间通信&#xff0c;使得两个不同的进程间可以交换数据&#xff0c;操作系统必须提供两个进程可以同时访问的内存空间&am…

MySQL之MHA高可用配置及故障切换

目录 一、MHA概念 1、MHA的组成 2、MHA的特点 3、主从复制有多少种复制方法 二、搭建MySqlMHA部署 1&#xff0e;Master、Slave1、Slave2 节点上安装 mysql 2&#xff0e;修改 Master、Slave1、Slave2 节点的 Mysql主配置文件/etc/my.cnf 3. 配置 mysql 一主两从 4、安…

关于el-input和el-select宽度不一致问题解决

1. 情景一 单列布局 对于上图这种情况&#xff0c;只需要给el-select加上style"width: 100%"即可&#xff0c;如下&#xff1a; <el-select v-model"fjForm.region" placeholder"请选择阀门类型" style"width: 100%"><el-o…

【轻量化网络】MobileNet系列

MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications, CVPR2017 论文&#xff1a;https://arxiv.org/abs/1704.04861 代码&#xff1a; 解读&#xff1a;【图像分类】2017-MobileNetV1 CVPR_說詤榢的博客-CSDN博客 MobileNetV2: Inverted …

如何使用PySide2将designer设计的ui文件加载到Python类上鼠标拖拽显示路径

应用场景&#xff1a; designer快速设计好UI文件后&#xff0c;需要增加一些特别的界面功能&#xff0c;如文件拖拽显示文件路径功能。 方法如下&#xff1a; from PySide2.QtWidgets import QApplication, QMainWindow from PySide2.QtUiTools import loadUiTypeUi_MainWindo…

Java中wait和notify详解

线程的调度是无序的&#xff0c;随机的&#xff0c;但是也是有一定的需求场景&#xff0c;希望能够有序执行&#xff0c;join算是一种控制顺序的方式&#xff08;功能有限&#xff09;——》一个线程执行完&#xff0c;才能执行另一个线程&#xff01; 本文主要讲解的&#xf…

【工具使用】Dependency Walker使用

一&#xff0c;简介 在工作过程中常常会遇到编译的dll库运行不正常的情况&#xff0c;那就需要确认dll库是否编译正常&#xff0c;即是否将函数编译到dll中去。今天介绍一种查看dll库中函数定义的工具——Dependency walker。 二&#xff0c;软件介绍 Dependency Walker是一…

CSS3技巧36:backdrop-filter 背景滤镜

CSS3 有 filter 滤镜属性&#xff0c;能给内容&#xff0c;尤其是图片&#xff0c;添加各种滤镜效果。 filter 滤镜详见博文&#xff1a;CSS3中强大的filter(滤镜)属性_css3滤镜_stones4zd的博客-CSDN博客 后续&#xff0c;CSS3 又新增了 backdrop-filter 背景滤镜。 backdr…