简单的二元语言模型bigram实现

  • 内容总结归纳自视频:【珍藏】从头开始用代码构建GPT - 大神Andrej Karpathy 的“神经网络从Zero到Hero 系列”之七_哔哩哔哩_bilibili 
  • 项目:https://github.com/karpathy/ng-video-lecture

Bigram模型是基于当前Token预测下一个Token的模型。例如,如果输入序列是`[A, B, C]`,那么模型会根据`A`预测`B`,根据`B`预测`C`,依此类推,实现自回归生成。在生成新Token时,通常只需要最后一个Token的信息,因为每个预测仅依赖于当前Token。

1. 训练batch数据形式

训练数据是:

训练目标是使得损失函数:-log(p)最小,当前token是24时,下一个字符是43的概率-log(p[0,targets[0])=-log(p[0],43])损失最小,即概率p[0,43]最大:

2. 定义词嵌入层

nn.Embedding 层输出的是可学习的浮点数,将token索引 (B,T) 直接映射为logits,即输入(4,8),输出 (4,8,65),其中输入每个数字,被映射成logit向量(这些值通过 F.cross_entropy 内部自动进行 softmax 转换为概率分布),比如上面输入tokens有个24被映射成如下。

logits = [1.0, 0.5, -2.0, ..., 3.2]  # 共65个浮点数

softmax后得到。

probs = [0.15, 0.12, 0.01, ..., 0.20]  # 和为1的概率分布

这样输出的是每个位置的概率分布。

交叉熵函数会自动计算每个位置的概率分布与真实标签之间的损失,并取平均。

简单的大语言模型,基于Bigram的结构,即每个token仅根据前一个token来预测下一个token。具体实现如下。

from torch.nn import functional as F  # 导入PyTorch函数模块
torch.manual_seed(1337)  # 固定随机种子保证结果可复现class BigramLanguageModel(nn.Module):  # 定义Bigram语言模型类def __init__(self, vocab_size):super().__init__()  # 继承父类初始化方法# 定义词嵌入层:将token索引直接映射为logits# 输入输出维度均为vocab_size(词汇表大小)self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets):# 前向传播函数# idx: inputs, 输入序列 (B, T),B=批次数,T=序列长度# targets: 目标序列 (B, T)# (4,8) -> (4,8,65)logits = self.token_embedding_table(idx)  # (B, T, C)# 通过嵌入层获得每个位置的概率分布,C=词汇表大小# (4*8,C), (4*8,) -> (1,)B, T, C = logits.shape  # 解包维度:批次数、序列长度、词表大小logits = logits.view(B*T, C)  # 展平为二维张量 (B*T, C)targets = targets.view(B*T)    # 目标展平为一维张量 (B*T)loss = F.cross_entropy(logits, targets)  # 计算交叉熵损失return logits, loss  # 返回logits(未归一化概率)和损失值# 假设 vocab_size=65(例如52字母+标点)
vocab_size = 65
m = BigramLanguageModel(vocab_size)  # 实例化模型# 假设输入数据(代码中未定义):
# xb: 输入批次 (B=4, T=8),例如 tensor([[1,2,3,...], ...])
# yb: 目标批次 (B=4, T=8)
logits, loss = m(xb, yb)  # 执行前向传播print(logits.shape)  # 输出logits形状:torch.Size([32, 65])
# 解释:32 = B*T = 4*8,65=词表大小(每个位置65种可能)print(loss)  # 输出损失值:tensor(4.8786, grad_fn=<NllLossBackward>)
# 解释:初始随机参数下,损失值约为-ln(1/65)=4.17,实际值因参数初始化略有波动

3. 代码逻辑分步解释

# 假设输入和目标的形状均为 (B=4, T=8)
# 输入示例(第一个样本):
inputs[0] = [24, 43, 58, 5, 57, 1, 46, 43]
targets[0] = [43, 58, 5, 57, 1, 46, 43, 39]

3.1 Softmax后的概率分布意义 

当模型处理输入序列时,每个位置会输出一个长度为vocab_size的logits向量。即

输入: (4,8);

输出:(4,8,65). 65维度向量是每个输入token的下一个token的概率分布。

例如,当输入序列为 [24, 43, 58, 5, 57, 1, 46, 43] 时:
- 在第1个位置(token=24),模型预测下一个token(对应target=43)的概率分布p[0].shape=(65,),下一个输出是43的概率为p[0][target[0]]=p[0][43];
- 在第2个位置(token=43),模型预测下一个token(对应target=58)的概率分布p[1].shape=(65),下一个输出是58的概率为p[1][target[1]]=p[1][58];
- 以此类推,每个位置的logits经过softmax后得到一个概率分布,即每个输入位置,都会预测下一个token概率分布
具体来说:
   logits.shape = (4, 8, 65) → softmax后形状不变->p.shape=(4,8,65),但每行的65个值变为概率(和为1)
这些概率表示模型认为「当前token的下一个token」是词汇表中各token的可能性。

3.2 交叉熵计算步骤

假设logits初始形状为 (4, 8, 65)
B, T, C = logits.shape  # B=4, T=8, C=65

# 展平logits和targets:
logits_flat = logits.view(B*T, C)  # 形状 (32, 65)
targets_flat = targets.view(B*T)    # 形状 (32,)

# 交叉熵计算(PyTorch内部过程):
# 对logits_flat的每一行(共32行)做softmax,得到概率分布probs (32, 65)
# 对每个样本i,取probs[i][targets_flat[i]],即真实标签对应的预测概率(此概率是下一个token是targets_flat[i]的概率
# 计算负对数损失:loss = -mean(log(probs[i][targets_flat[i]]))(pytorch实现是将targets_flat所谓索引
loss = F.cross_entropy(logits_flat, targets_flat)  # 输出标量值

3.3 示例计算

# 以第一个样本的第一个位置为例:
# 输入token=24,目标token=43
# 模型输出的logits[0,0]是一个65维向量(这里logits.shape=[4,8,65]),例如:
logits_example = logits[0,0]  # 形状 (65,)
probs_example = F.softmax(logits_example, dim=-1)  # 形状 (65,)
# 假设probs_example[43] = 0.15(模型预测下一个token=43的概率为15%)
# 则此位置的损失为 -log(0.15) ≈ 1.897 (注意-log(p)是一个x范围在[0,1]之间单调递减函数)

# 最终损失是所有32个位置类似计算的均值。
# 初始损失约为4.87(接近均匀分布的理论值 -ln(1/65)≈4.17)

4. 测试生成文本

# super simple bigram model
class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()# each token directly reads off the logits for the next token from a lookup tableself.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets=None):# idx and targets are both (B,T) tensor of integerslogits = self.token_embedding_table(idx) # (B,T,C)if targets is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.view(B*T, C)targets = targets.view(B*T)loss = F.cross_entropy(logits, targets)return logits, lossdef generate(self, idx, max_new_tokens):# idx is (B, T) array of indices in the current contextfor _ in range(max_new_tokens):# get the predictionslogits, loss = self(idx) # 没有输入target时,返回的logits未被展平。  # focus only on the last time steplogits = logits[:, -1, :] # (B,T,C) -> (B, C)# apply softmax to get probabilitiesprobs = F.softmax(logits, dim=-1) # (B, C)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1) # (B, 1)# append sampled index to the running sequenceidx = torch.cat((idx, idx_next), dim=1) # (B, T+1)return idx

以下是Bigram模型生成过程的逐步详解,以输入序列[24, 43, 58, 5, 57, 1, 46, 43]为例,说明模型如何从初始输入[24]开始逐步预测下一个词: 

4.1 初始输入:[24]

  • 输入形状idx = [[24]]B=1批次,T=1序列长度)。

  • 前向传播

    • 通过嵌入层,模型输出logits形状为(1, 1, 65),表示对当前词24的下一个词的预测分数。

    • 假设logits[0, 0, 43] = 5.0(词43的logit较高),其他位置logits较低(如logits[0, 0, :] = [..., 5.0, ...])。

  • 概率分布

    • 对logits应用softmax,得到概率分布probs。例如:

      probs = [0.01, ..., 0.8(对应43), 0.01, ...]  # 总和为1
  • 采样

    • 根据probs,使用torch.multinomial采样,选中词43的概率最大。

  • 更新输入

    • 43拼接到序列末尾,新输入为idx = [[24, 43]](形状(1, 2))。


4.2 输入:[24, 43]

  • 前向传播

    • 模型处理整个序列,输出logits形状为(1, 2, 65),对应两个位置的预测:

      • 第1个位置(词24)预测下一个词(已生成43)。

      • 第2个位置(词43)预测下一个词。

    • 提取最后一个位置的logits:logits[:, -1, :](形状(1, 65))。

    • 假设logits[0, -1, 58] = 6.0(词58的logit较高)。

  • 概率分布

    • probs = [0.01, ..., 0.85(对应58), 0.01, ...]

  • 采样

    • 选中词58

  • 更新输入

    • 新输入为idx = [[24, 43, 58]](形状(1, 3))。


4.3 输入:[24, 43, 58]

  • 前向传播

    • logits形状为(1, 3, 65)

    • 提取最后一个位置(词58)的logits,假设logits[0, -1, 5] = 4.5

  • 概率分布

    • probs = [0.01, ..., 0.7(对应5), ...]

  • 采样

    • 选中词5

  • 更新输入

    • 新输入为idx = [[24, 43, 58, 5]](形状(1, 4))。


4.4 重复生成直到序列完成

  • 后续步骤

    • 输入[24, 43, 58, 5] → 预测词57

    • 输入[24, 43, 58, 5, 57] → 预测词1

    • 输入[24, 43, 58, 5, 57, 1] → 预测词46

    • 输入[24, 43, 58, 5, 57, 1, 46] → 预测词43

  • 最终序列

    • idx = [[24, 43, 58, 5, 57, 1, 46, 43]]

注意:上面输入序列是越来越长的,为何说预测下一个词只跟上一个词有关?如果只跟一个词有关,为何不每次只输入一个词,然后预测下一个词?

虽然理论上可以仅传递最后一个词,但实际实现中传递完整序列的原因(视频作者说的,固定generate函数形式,我这里理解的是代码简洁):

  • 代码简洁性:无需在每次生成时截取最后一个词,直接复用统一的前向传播逻辑;

实验验证

若修改代码,每次仅传递最后一个词:

def generate(self, idx, max_new_tokens):for _ in range(max_new_tokens):last_token = idx[:, -1:]          # 仅取最后一个词 (B, 1)logits, _ = self(last_token)       # 输出形状 (B, 1, C)probs = F.softmax(logits[:, -1, :], dim=-1)idx_next = torch.multinomial(probs, num_samples=1)idx = torch.cat((idx, idx_next), dim=1)return idx

4.5 完整代码

import torch
import torch.nn as nn
from torch.nn import functional as F# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 3000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# ------------torch.manual_seed(1337)# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:text = f.read()# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]# data loading
def get_batch(split):# generate a small batch of data of inputs x and targets ydata = train_data if split == 'train' else val_dataix = torch.randint(len(data) - block_size, (batch_size,))x = torch.stack([data[i:i+block_size] for i in ix])y = torch.stack([data[i+1:i+block_size+1] for i in ix])x, y = x.to(device), y.to(device)return x, y@torch.no_grad()
def estimate_loss():out = {}model.eval()for split in ['train', 'val']:losses = torch.zeros(eval_iters)for k in range(eval_iters):X, Y = get_batch(split)logits, loss = model(X, Y)losses[k] = loss.item()out[split] = losses.mean()model.train()return out# super simple bigram model
class BigramLanguageModel(nn.Module):def __init__(self, vocab_size):super().__init__()# each token directly reads off the logits for the next token from a lookup tableself.token_embedding_table = nn.Embedding(vocab_size, vocab_size)def forward(self, idx, targets=None):# idx and targets are both (B,T) tensor of integerslogits = self.token_embedding_table(idx) # (B,T,C)if targets is None:loss = Noneelse:B, T, C = logits.shapelogits = logits.view(B*T, C)targets = targets.view(B*T)loss = F.cross_entropy(logits, targets)return logits, lossdef generate(self, idx, max_new_tokens):# idx is (B, T) array of indices in the current contextfor _ in range(max_new_tokens):# get the predictionslogits, loss = self(idx)# focus only on the last time steplogits = logits[:, -1, :] # becomes (B, C)# apply softmax to get probabilitiesprobs = F.softmax(logits, dim=-1) # (B, C)# sample from the distributionidx_next = torch.multinomial(probs, num_samples=1) # (B, 1)# append sampled index to the running sequenceidx = torch.cat((idx, idx_next), dim=1) # (B, T+1)return idxmodel = BigramLanguageModel(vocab_size)
m = model.to(device)# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)for iter in range(max_iters):# every once in a while evaluate the loss on train and val setsif iter % eval_interval == 0:losses = estimate_loss()print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")# sample a batch of dataxb, yb = get_batch('train')# evaluate the losslogits, loss = model(xb, yb)optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/30962.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

猫耳大型活动提效——组件低代码化

1. 引言 猫耳前端在开发活动的过程中&#xff0c;经历过传统的 pro code 阶段&#xff0c;即活动页面完全由前端开发编码实现&#xff0c;直到 2020 年接入公司内部的低代码活动平台&#xff0c;满足了大部分日常活动的需求&#xff0c;运营可自主配置活动并上线&#xff0c;释…

数据库基础以及基本建库建表的简单操作

文章目录 一、数据库是啥1.1、数据库的概念1.1、关系型数据库、非关系型数据库1.1、数据库服务器&#xff0c;数据库与表之间的关系 二、为啥要使用数据库2.1&#xff1a;传统数据文件存储2.2&#xff1a;数据库存储数据2.3、结论 三、使用数据库了会咋样四、应该咋用数据库&am…

常用无功功率算法的C语言实现(二)

0 前言 尽管数字延迟法和积分移相法在不间断采样的无功功率计算中得到了广泛应用,但它们仍存在一些固有缺陷。 对于数字延迟法而言,其需要额外存储至少1/4周期的采样点,在高采样频率的场景下,这对存储资源的需求不可忽视。而积分移相法虽然避免了额外的存储开销,但为了抑制…

【Linux】初识线程

目录 一、什么是线程&#xff1a; 重定义线程和进程&#xff1a; 执行流&#xff1a; Linux中线程的实现方案&#xff1a; 二、再谈进程地址空间 三、小结&#xff1a; 1、概念&#xff1a; 2、进程与线程的关系&#xff1a; 3、线程优点&#xff1a; 4、线程…

【单片机】ARM 处理器简介

ARM 公司简介 ARM&#xff08;Advanced RISC Machine&#xff09; 是英国 ARM 公司&#xff08;原 Acorn RISC Machine&#xff09; 开发的一种精简指令集&#xff08;RISC&#xff09; 处理器架构。ARM 处理器因其低功耗、高性能、广泛适用性&#xff0c;成为嵌入式系统、移动…

​​《从事件冒泡到处理:前端事件系统的“隐形逻辑”》

“那天在document见到你的第一眼&#xff0c;我就下定决心要陪你到天荒地老” ---React 我将从事件从出现到被处理的各个过程来介绍事件机制&#xff1a; 这张图片给我们展示了react事件的各个阶段&#xff0c;我们可以看到有DOM&#xff0c;合成事件层&#xff0c;还有…

Django小白级开发入门

1、Django概述 Django是一个开放源代码的Web应用框架&#xff0c;由Python写成。采用了MTV的框架模式&#xff0c;即模型M&#xff0c;视图V和模版T。 Django 框架的核心组件有&#xff1a; 用于创建模型的对象关系映射为最终用户设计较好的管理界面URL 设计设计者友好的模板…

课程《Deep Learning Specialization》

在coursera上&#xff0c;Deep Learning Specialization 课程内容如下图所示&#xff1a;

Java【网络原理】(3)网络编程续

目录 1.前言 2.正文 2.1ServerSocket类 2.2Socket类 2.3Tcp回显服务器 2.3.1TcpEchoServer 2.3.2TcpEchoClient 3.小结 1.前言 哈喽大家好&#xff0c;今天继续进行计算机网络的初阶学习&#xff0c;今天学习的是tcp回显服务器的实现&#xff0c;正文开始 2.正文 在…

安装remixd,在VScode创建hardhat

在终端&#xff0c;以管理员身份&#xff0c;cmd 需要科学上网 npm install -g remix-project/remixd 在vscode插件中&#xff0c;安装solidity插件&#xff0c;是暗灰色那款 1.将nodeJs的版本升级至18以上 2.在vscode打开一个新的文件&#xff0c;在终端输入 npx hardhat 3.…

微服务拆分-远程调用

我们在查询购物车列表的时候&#xff0c;它有一个需求&#xff0c;就是不仅仅要查出购物车当中的这些商品信息&#xff0c;同时还要去查到购物车当中这些商品的最新的价格和状态信息&#xff0c;跟购物车当中的快照进行一个对比&#xff0c;从而去提醒用户。 现在我们已经做了服…

TCP/IP 5层协议簇:网络层(ICMP协议)

1. TCP/IP 5层协议簇 如下&#xff1a; 和ip协议有关的才有ip头 2. ICMP 协议 ICMP协议没有端口号&#xff0c;因为不去上层&#xff0c;上层协议采用端口号

Uniapp 页面返回不刷新?两种方法防止 onShow 触发多次请求!

目录 前言1. 变量&#xff08;不生效&#xff09;2. 延迟&#xff08;生效&#xff09; 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF 在 Uniapp 中&#xff0c;使用 onShow() 钩子来监听页面显示&#xff0…

java_了解反射机制

目录 1. 定义 2. 用途 3. 反射基本信息 4. 反射相关的类 4.1 class类&#xff08;反射机制的起源&#xff09; 4.1.1 Class类中的相关方法&#xff08;方法的具体使用在后面的示例中&#xff09; 4.2 反射的示例 4.2.1 获得Class对象的三种方式 4.2.2 反射的使用 Fiel…

基于Python的商品销量的数据分析及推荐系统

一、研究背景及意义 1.1 研究背景 随着电子商务的快速发展&#xff0c;商品销售数据呈现爆炸式增长。这些数据中蕴含着消费者行为、市场趋势、商品关联等有价值的信息。然而&#xff0c;传统的数据分析方法难以处理海量、多源的销售数据&#xff0c;无法满足现代电商的需求。…

P8662 [蓝桥杯 2018 省 AB] 全球变暖--DFS

P8662 [蓝桥杯 2018 省 AB] 全球变暖--dfs 题目 解析讲下DFS代码 题目 解析 这道题的思路就是遍历所有岛屿&#xff0c;判断每一块陆地是否会沉没。对于这种图的遍历&#xff0c;我们首先应该想到DFS。 代码的注意思想就是&#xff0c;在主函数中遍历找出所有岛屿&#xff0c…

tiktok web登录 分析

声明: 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 逆向分析 部分代码 response reques…

邮件发送器:使用 Python 构建带 GUI 的邮件自动发送工具

在本篇博客中&#xff0c;我们将深入解析一个使用 wxPython 构建的邮件发送器 GUI 程序。这个工具能够自动查找指定目录中的文件作为附件&#xff0c;并提供邮件发送功能。本文将从功能、代码结构、关键技术等方面进行详细分析。 C:\pythoncode\new\ATemplateFromWeekReportByM…

pyside6学习专栏(十一):在PySide6中实现一简易的画板程序

pyside6学习专栏(十一):在PySide6中实现一简易的画板程序&#xff0c;实现画直线、矩形、填充矩形、圆、椭圆、随手画、文本。为减少代码量&#xff0c;所画形状的颜色、线宽、线型、填充形状、字体、字号等采用随机方式&#xff0c;只作为学习在Python中绘画的基本操作。 主界…

Android 屏幕适配 Tips

概念 屏幕尺寸&#xff1a;屏幕的对角线的长度屏幕分辨率&#xff1a;屏幕分辨率是指在横纵向上的像素点数&#xff0c;单位是px&#xff0c;1px1个像素点。一般以纵向像素x横向像素&#xff0c;如1960x1080屏幕像素密度&#xff1a;每英寸上的像素点数&#xff0c;单位是dpi …