2023年的深度学习入门指南(26) - 在自己电脑上运行通义千问7b模型

2023年的深度学习入门指南(26) - 在自己电脑上运行通义千问7b模型

通过量化,通义千问4位量化的模型大小为5.86G,可以在3060等小于16G的家用GPU上也可以运行起来。

通义千问7b的量化运行

通义千问7b提供了4位量化好的Qwen/Qwen-7B-Chat-Int4模型,我们直接调用就好。

首先安装依赖包:

pip install transformers==4.32.0
pip install accelerate
pip install tiktoken
pip install einops
pip install transformers_stream_generator==0.0.4
pip install scipy
pip install auto-gptq optimum

如果你是Linux环境的话,可以安装下Flash-Attention来加速:

git clone -b v1.0.8 https://github.com/Dao-AILab/flash-attention
cd flash-attention && pip install .

Windows下暂时还用不了,这个不是必选步骤。

下面我们就可以来写代码调用通义千问7b了:

from transformers import AutoTokenizer, AutoModelForCausalLM# Note: The default behavior now has injection attack prevention off.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat-Int4", trust_remote_code=True)model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat-Int4",device_map="auto",trust_remote_code=True
).eval()
response, history = model.chat(tokenizer, "生成用C++将字符串倒序的代码", history=None)
print(response)

生成结果如下:

以下是C++中将字符串逆序的示例代码:#include <iostream>
#include <string>int main() {std::string str = "Hello, World!";std::string reversedStr = str;std::reverse(reversedStr.begin(), reversedStr.end());std::cout << reversedStr << std::endl;return 0;
}首先,我们定义了一个包含字符串的变量 `str`。然后,我们定义了一个空字符串变量 `reversedStr`,用于存储逆序后的字符串。接下来,我们使用 `std::reverse()` 函数将 `str` 中的字符逆序。该函数需要一个迭代器范围作为参数,表示要逆序的字符序列。在这里,我们使用 `str.begin()` 和 `str.end()` 获取字符串的起始和结束迭代器,然后将它们传递给 `std::reverse()` 函数。最后,我们输出逆序后的字符串。

我是在3060 GPU上运行成功的。

下面我们继续讲解通义千问7B的源代码。

通义千问7b的全连接网络

除了使用了silu激活函数之外,其他就是基本的全连接网络了。

class QWenMLP(nn.Module):def __init__(self, config):super().__init__()self.w1 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)self.w2 = nn.Linear(config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias)ff_dim_in = config.intermediate_size // 2self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)def forward(self, hidden_states):a1 = self.w1(hidden_states)a2 = self.w2(hidden_states)intermediate_parallel = a1 * F.silu(a2)output = self.c_proj(intermediate_parallel)return output

SiLU 函数是一种神经网络中的激活函数,全称是 Sigmoid Linear Unit, 也被称为 Swish 函数。它由 Google Brain 在 2017 年提出,是一种非线性激活函数,能够有效地对神经网络的输入进行非线性变换。

SiLU 函数的定义如下:

f(x) = x * sigmoid(x)

其中,sigmoid 函数是 Sigmoid 函数,定义如下:

sigmoid(x) = 1 / (1 + exp(-x))

SiLU 函数的特点如下:

  • 正数区域内,SiLU 函数的输出与 ReLU 函数的输出相同。
  • 在负数区域内,SiLU 函数的输出与 sigmoid 函数的输出相同。
  • SiLU 函数在整个定义域内都是可微的,这使得在反向传播过程中的梯度计算更加稳定。
  • SiLU函数不是单调递增的,而是在x≈−1.28时达到全局最小值−0.28,这可以起到一个隐式正则化的作用,抑制过大的权重

Transformer块

下面我们将RMSNorm,QWenAttention和QWenMLP三者搭建成QWenBlock,就类似于LLaMA中的TransformerBlock:

Client QWenBlock RMSNorm 1 QWenAttention RMSNorm 2 QWenMLP forward(hidden_states, ...) ln_1(hidden_states) Return layernorm_output attn(layernorm_output, ...) Return attn_outputs Split attn_output and other outputs Calculate layernorm_input ln_2(layernorm_input) Return layernorm_output mlp(layernorm_output) Return mlp_output Calculate hidden_states Prepare outputs with cache Prepare outputs without cache alt [use_cache is True] [use_cache is False] Return outputs Client QWenBlock RMSNorm 1 QWenAttention RMSNorm 2 QWenMLP
class QWenBlock(nn.Module):def __init__(self, config):super().__init__()hidden_size = config.hidden_sizeself.bf16 = config.bf16self.ln_1 = RMSNorm(hidden_size,eps=config.layer_norm_epsilon,)self.attn = QWenAttention(config)self.ln_2 = RMSNorm(hidden_size,eps=config.layer_norm_epsilon,)self.mlp = QWenMLP(config)def forward(self,hidden_states: Optional[Tuple[torch.FloatTensor]],rotary_pos_emb: Optional[List[torch.Tensor]] = None,registered_causal_mask: Optional[torch.Tensor] = None,layer_past: Optional[Tuple[torch.Tensor]] = None,attention_mask: Optional[torch.FloatTensor] = None,head_mask: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,use_cache: Optional[bool] = False,output_attentions: Optional[bool] = False,):layernorm_output = self.ln_1(hidden_states)attn_outputs = self.attn(layernorm_output,rotary_pos_emb,registered_causal_mask=registered_causal_mask,layer_past=layer_past,attention_mask=attention_mask,head_mask=head_mask,use_cache=use_cache,output_attentions=output_attentions,)attn_output = attn_outputs[0]outputs = attn_outputs[1:]residual = hidden_stateslayernorm_input = attn_output + residuallayernorm_output = self.ln_2(layernorm_input)residual = layernorm_inputmlp_output = self.mlp(layernorm_output)hidden_states = residual + mlp_outputif use_cache:outputs = (hidden_states,) + outputselse:outputs = (hidden_states,) + outputs[1:]return outputs

这一模块主要就是将一些参数传递给上节我们介绍过的QWenAttention:

  • hidden_states:一个可选的元组,包含了上一层的输出张量,形状为(batch_size, sequence_length, hidden_size)。
  • rotary_pos_emb:一个可选的列表,包含了旋转位置编码张量,形状为(batch_size, sequence_length, hidden_size)。
  • registered_causal_mask:一个可选的张量,用于注册因果掩码,防止模型看到未来的信息。形状为(batch_size, sequence_length, sequence_length)。
  • layer_past:一个可选的元组,包含了上一层的注意力键值对张量,用于实现缓存机制,加速生成过程。形状为(2, batch_size, num_heads, sequence_length, head_dim)。
  • attention_mask:一个可选的浮点张量,用于对输入序列进行掩码,忽略无效的位置或填充部分。形状为(batch_size, sequence_length)或(batch_size, 1, 1, sequence_length)。
  • head_mask:一个可选的浮点张量,用于对注意力头进行掩码,随机删除一些头以增加模型的鲁棒性。形状为(num_heads,)或(1, 1, num_heads, 1)。
  • encoder_hidden_states:一个可选的张量,用于实现编码器-解码器结构时,传递编码器的输出给解码器。形状为(batch_size, encoder_sequence_length, hidden_size)。
  • encoder_attention_mask:一个可选的浮点张量,用于实现编码器-解码器结构时,对编码器输出进行掩码。形状为(batch_size, encoder_sequence_length)或(batch_size, 1, 1, encoder_sequence_length)。
  • use_cache:一个可选的布尔值,用于指示是否使用缓存机制。
  • output_attentions:一个可选的布尔值,用于指示是否输出注意力权重张量。

RMSNorm

RMSNorm我们已经讲过多次的,这里就不多介绍了:

class RMSNorm(torch.nn.Module):def __init__(self, dim: int, eps: float = 1e-6):super().__init__()self.eps = epsself.weight = nn.Parameter(torch.ones(dim))def _norm(self, x):return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)def forward(self, x):if rms_norm is not None and x.is_cuda:return rms_norm(x, self.weight, self.eps)else:output = self._norm(x.float()).type_as(x)return output * self.weight

位置编码

还记得讲百川模型代码时我们遇到的einsum吗?在千问的代码里我们会再次遇到这样的爱因斯坦风格,这次我们用到的是一个库einops。

在einops的加持下,我们可以将维度变换的操作变得更有可读性:

            from einops import rearrangeemb = rearrange(emb, "n d -> 1 n 1 d")

rearrange函数可以根据字符串表达式来重新排列张量维度。

这里的"n d -> 1 n 1 d"表示:

  • 从(n, d)形状
  • 重新排列为(1, n, 1, d)形状
    也就是在emb张量的维度1(n个向量)前面增加两维,变成1和1。

其余的还是使用cos和sin函数作cache:

class RotaryEmbedding(torch.nn.Module):def __init__(self, dim, base=10000):super().__init__()self.dim = dimself.base = baseself.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))if importlib.util.find_spec("einops") is None:raise RuntimeError("einops is required for Rotary Embedding")self._rotary_pos_emb_cache = Noneself._seq_len_cached = 0self._ntk_alpha_cached = 1.0def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):seqlen = max_seq_len + offsetif seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))self.inv_freq = 1.0 / (base** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()/ self.dim))self._seq_len_cached = max(2 * seqlen, 16)self._ntk_alpha_cached = ntk_alphaseq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)emb = torch.cat((freqs, freqs), dim=-1)from einops import rearrangeemb = rearrange(emb, "n d -> 1 n 1 d")cos, sin = emb.cos(), emb.sin()self._rotary_pos_emb_cache = [cos, sin]def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)cos, sin = self._rotary_pos_emb_cachereturn [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]

千问7B的旋转函数也是用einops.rearrange来实现的:

def _rotate_half(x):from einops import rearrangex = rearrange(x, "... (j d) -> ... j d", j=2)x1, x2 = x.unbind(dim=-2)return torch.cat((-x2, x1), dim=-1)

最后是apply_rotary_pos_emb函数,作用是将旋转位置编码应用到输入张量t上。

def apply_rotary_pos_emb(t, freqs):cos, sin = freqsif apply_rotary_emb_func is not None and t.is_cuda:t_ = t.float()cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]output = apply_rotary_emb_func(t_, cos, sin).type_as(t)return outputelse:rot_dim = freqs[0].shape[-1]cos, sin = freqst_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]t_ = t_.float()t_pass_ = t_pass_.float()t_ = (t_ * cos) + (_rotate_half(t_) * sin)return torch.cat((t_, t_pass_), dim=-1).type_as(t)

apply_rotary_pos_emb的主要步骤:

  • 从freqs中分离出cos和sin编码。
  • 如果CUDA环境且有apply_rotary_emb_func实现,直接调用该函数进行优化的旋转编码。
  • 否则,手动实现旋转编码:
  • 将t切分为要编码部分t_和不编码部分t_pass_。
  • 计算旋转编码后的t_。
  • 将编码后的t_和未编码的t_pass_拼接。
  • 返回拼接后的结果。

这样,当有优化实现时直接调用,否则用Python实现旋转位置编码。

旋转位置编码的作用是让模型表征更具局部性,使自注意力更聚焦在关键区域。这通常能提升长序列建模的性能。

通义千问的Transformer模型

tongyi

class QWenModel(QWenPreTrainedModel):_keys_to_ignore_on_load_missing = ["attn.masked_bias"]def __init__(self, config):super().__init__(config)self.vocab_size = config.vocab_sizeself.num_hidden_layers = config.num_hidden_layersself.embed_dim = config.hidden_sizeself.gradient_checkpointing = Falseself.use_dynamic_ntk = config.use_dynamic_ntkself.seq_length = config.seq_lengthself.wte = nn.Embedding(self.vocab_size, self.embed_dim)self.drop = nn.Dropout(config.emb_dropout_prob)if config.rotary_pct == 1.0:self.rotary_ndims = Noneelse:assert config.rotary_pct < 1self.rotary_ndims = int(config.kv_channels * config.rotary_pct)dim = (self.rotary_ndimsif self.rotary_ndims is not Noneelse config.kv_channels)self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)self.use_flash_attn = config.use_flash_attnself.is_fp32 = not (config.bf16 or config.fp16)if (self.use_flash_attnand flash_attn_unpadded_func is not Noneand not self.is_fp32):self.registered_causal_mask = Noneelse:max_positions = config.max_position_embeddingsself.register_buffer("registered_causal_mask",torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(1, 1, max_positions, max_positions),persistent=False,)self.h = nn.ModuleList([QWenBlock(config)for i in range(config.num_hidden_layers)])self.ln_f = RMSNorm(self.embed_dim,eps=config.layer_norm_epsilon,)self.post_init()

初始化的部分还是将之前介绍过的各模块组合在一起。

下面是虽然大但是主要是例行公事和错误判断的forward:

    def forward(self,input_ids: Optional[torch.LongTensor] = None,past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,):output_attentions = (output_attentionsif output_attentions is not Noneelse self.config.output_attentions)output_hidden_states = (output_hidden_statesif output_hidden_states is not Noneelse self.config.output_hidden_states)use_cache = use_cache if use_cache is not None else self.config.use_cachereturn_dict = (return_dict if return_dict is not None else self.config.use_return_dict)if input_ids is not None and inputs_embeds is not None:raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")elif input_ids is not None:input_shape = input_ids.size()input_ids = input_ids.view(-1, input_shape[-1])batch_size = input_ids.shape[0]elif inputs_embeds is not None:input_shape = inputs_embeds.size()[:-1]batch_size = inputs_embeds.shape[0]else:raise ValueError("You have to specify either input_ids or inputs_embeds")device = input_ids.device if input_ids is not None else inputs_embeds.deviceif token_type_ids is not None:token_type_ids = token_type_ids.view(-1, input_shape[-1])if position_ids is not None:position_ids = position_ids.view(-1, input_shape[-1])if past_key_values is None:past_length = 0past_key_values = tuple([None] * len(self.h))else:past_length = past_key_values[0][0].size(-2)if position_ids is None:position_ids = torch.arange(past_length,input_shape[-1] + past_length,dtype=torch.long,device=device,)position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])if attention_mask is not None:if batch_size <= 0:raise ValueError("batch_size has to be defined and > 0")attention_mask = attention_mask.view(batch_size, -1)attention_mask = attention_mask[:, None, None, :]attention_mask = attention_mask.to(dtype=self.dtype)attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).minencoder_attention_mask = Nonehead_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)if inputs_embeds is None:inputs_embeds = self.wte(input_ids)hidden_states = inputs_embedskv_seq_len = hidden_states.size()[1]if past_key_values[0] is not None:# past key values[0][0] shape: bs * seq_len * head_num * dimkv_seq_len += past_key_values[0][0].shape[1]if (self.use_dynamic_ntkand kv_seq_len == hidden_states.size()[1]and not self.training):context_value = math.log(kv_seq_len / self.seq_length, 2) + 1ntk_alpha = 2 ** math.ceil(context_value) - 1ntk_alpha = max(ntk_alpha, 1)else:ntk_alpha = self.rotary_emb._ntk_alpha_cachedrotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)for idx in range(len(rotary_pos_emb)):rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)hidden_states = self.drop(hidden_states)output_shape = input_shape + (hidden_states.size(-1),)if self.gradient_checkpointing and self.training:if use_cache:logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")use_cache = Falsepresents = () if use_cache else Noneall_self_attentions = () if output_attentions else Noneall_hidden_states = () if output_hidden_states else Nonefor i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):if output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)if self.gradient_checkpointing and self.training:def create_custom_forward(module):def custom_forward(*inputs):# None for past_key_valuereturn module(*inputs, use_cache, output_attentions)return custom_forwardoutputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block),hidden_states,rotary_pos_emb,self.registered_causal_mask,None,attention_mask,head_mask[i],encoder_hidden_states,encoder_attention_mask,)else:outputs = block(hidden_states,layer_past=layer_past,rotary_pos_emb=rotary_pos_emb,registered_causal_mask=self.registered_causal_mask,attention_mask=attention_mask,head_mask=head_mask[i],encoder_hidden_states=encoder_hidden_states,encoder_attention_mask=encoder_attention_mask,use_cache=use_cache,output_attentions=output_attentions,)hidden_states = outputs[0]if use_cache is True:presents = presents + (outputs[1],)if output_attentions:all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)hidden_states = self.ln_f(hidden_states)hidden_states = hidden_states.view(output_shape)# Add last hidden stateif output_hidden_states:all_hidden_states = all_hidden_states + (hidden_states,)if not return_dict:return tuple(v for v in [hidden_states, presents, all_hidden_states] if v is not None)return BaseModelOutputWithPast(last_hidden_state=hidden_states,past_key_values=presents,hidden_states=all_hidden_states,attentions=all_self_attentions,)

这实现了一个标准的Transformer编码器结构,有输入处理、Encoding块循环、输出后处理三个主要部分。使用了层规范化、多头自注意力、残差连接等机制。还支持caching、checkpoints、mask等功能。

预训练模型

下面再说一下QWenModel的基类,用于设置并行训练和保存点等信息的,继承自PreTrainedModel的类:

class QWenPreTrainedModel(PreTrainedModel):config_class = QWenConfigbase_model_prefix = "transformer"is_parallelizable = Falsesupports_gradient_checkpointing = True_no_split_modules = ["QWenBlock"]def __init__(self, *inputs, **kwargs):super().__init__(*inputs, **kwargs)def _init_weights(self, module):"""Initialize the weights."""if isinstance(module, nn.Linear):module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)if module.bias is not None:module.bias.data.zero_()elif isinstance(module, nn.Embedding):module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)if module.padding_idx is not None:module.weight.data[module.padding_idx].zero_()elif isinstance(module, RMSNorm):module.weight.data.fill_(1.0)for name, p in module.named_parameters():if name == "c_proj.weight":p.data.normal_(mean=0.0,std=(self.config.initializer_range/ math.sqrt(2 * self.config.num_hidden_layers)),)def _set_gradient_checkpointing(self, module, value=False):if isinstance(module, QWenModel):module.gradient_checkpointing = value

语言模型封装

上面的QWenModel返回的BaseModelOutputWithPast,如果要做成语言模型的话,还要封装成CausalLMOutputWithPast。

class QWenLMHeadModel(QWenPreTrainedModel):_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]def __init__(self, config):super().__init__(config)assert (config.bf16 + config.fp16 + config.fp32 <= 1), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0if autoset_precision:if SUPPORT_BF16:logger.warn("The model is automatically converting to bf16 for faster inference. ""If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".")config.bf16 = Trueelif SUPPORT_FP16:logger.warn("The model is automatically converting to fp16 for faster inference. ""If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\".")config.fp16 = Trueelse:config.fp32 = Trueif config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")if config.fp32:if SUPPORT_BF16:logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")elif SUPPORT_FP16:logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")if config.use_flash_attn == "auto":if config.bf16 or config.fp16:logger.warn("Try importing flash-attention for faster inference...")config.use_flash_attn = Trueelse:config.use_flash_attn = Falseif config.use_flash_attn and config.fp32:logger.warn("Flash attention will be disabled because it does NOT support fp32.")if config.use_flash_attn:_import_flash_attn()self.transformer = QWenModel(config)self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)if config.bf16:self.transformer.bfloat16()self.lm_head.bfloat16()if config.fp16:self.transformer.half()self.lm_head.half()self.post_init()def get_output_embeddings(self):return self.lm_headdef set_output_embeddings(self, new_embeddings):self.lm_head = new_embeddingsdef prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):token_type_ids = kwargs.get("token_type_ids", None)if past_key_values:input_ids = input_ids[:, -1].unsqueeze(-1)if token_type_ids is not None:token_type_ids = token_type_ids[:, -1].unsqueeze(-1)attention_mask = kwargs.get("attention_mask", None)position_ids = kwargs.get("position_ids", None)if attention_mask is not None and position_ids is None:position_ids = attention_mask.long().cumsum(-1) - 1position_ids.masked_fill_(attention_mask == 0, 1)if past_key_values:position_ids = position_ids[:, -1].unsqueeze(-1)else:position_ids = Noneif inputs_embeds is not None and past_key_values is None:model_inputs = {"inputs_embeds": inputs_embeds}else:model_inputs = {"input_ids": input_ids}model_inputs.update({"past_key_values": past_key_values,"use_cache": kwargs.get("use_cache"),"position_ids": position_ids,"attention_mask": attention_mask,"token_type_ids": token_type_ids,})return model_inputsdef forward(self,input_ids: Optional[torch.LongTensor] = None,past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,attention_mask: Optional[torch.FloatTensor] = None,token_type_ids: Optional[torch.LongTensor] = None,position_ids: Optional[torch.LongTensor] = None,head_mask: Optional[torch.FloatTensor] = None,inputs_embeds: Optional[torch.FloatTensor] = None,encoder_hidden_states: Optional[torch.Tensor] = None,encoder_attention_mask: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,) -> Union[Tuple, CausalLMOutputWithPast]:return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)transformer_outputs = self.transformer(input_ids,past_key_values=past_key_values,attention_mask=attention_mask,token_type_ids=token_type_ids,position_ids=position_ids,head_mask=head_mask,inputs_embeds=inputs_embeds,encoder_hidden_states=encoder_hidden_states,encoder_attention_mask=encoder_attention_mask,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)hidden_states = transformer_outputs[0]lm_logits = self.lm_head(hidden_states)loss = Noneif labels is not None:labels = labels.to(lm_logits.device)shift_logits = lm_logits[..., :-1, :].contiguous()shift_labels = labels[..., 1:].contiguous()loss_fct = CrossEntropyLoss()loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))if not return_dict:output = (lm_logits,) + transformer_outputs[1:]return ((loss,) + output) if loss is not None else outputreturn CausalLMOutputWithPast(loss=loss,logits=lm_logits,past_key_values=transformer_outputs.past_key_values,hidden_states=transformer_outputs.hidden_states,attentions=transformer_outputs.attentions,)

在forward之外,语言模型还需要封装一个生成函数。主要也是做一些配置,然后调用父类的生成函数:

    def generate(self,inputs: Optional[torch.Tensor] = None,generation_config: Optional[GenerationConfig] = None,logits_processor: Optional[LogitsProcessorList] = None,stopping_criteria: Optional[StoppingCriteriaList] = None,prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,synced_gpus: Optional[bool] = None,assistant_model: Optional["PreTrainedModel"] = None,streamer: Optional["BaseStreamer"] = None,**kwargs,) -> Union[GenerateOutput, torch.LongTensor]:generation_config = generation_config if generation_config is not None else self.generation_config# Process stop_words_ids.stop_words_ids = kwargs.pop("stop_words_ids", None)if stop_words_ids is None and generation_config is not None:stop_words_ids = getattr(generation_config, "stop_words_ids", None)if stop_words_ids is None:stop_words_ids = getattr(generation_config, "stop_words_ids", None)if stop_words_ids is not None:stop_words_logits_processor = StopWordsLogitsProcessor(stop_words_ids=stop_words_ids,eos_token_id=generation_config.eos_token_id,)if logits_processor is None:logits_processor = LogitsProcessorList([stop_words_logits_processor])else:logits_processor.append(stop_words_logits_processor)return super().generate(inputs,generation_config=generation_config,logits_processor=logits_processor,stopping_criteria=stopping_criteria,prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,synced_gpus=synced_gpus,assistant_model=assistant_model,streamer=streamer,**kwargs,)

聊天功能封装

    def chat(self,tokenizer: PreTrainedTokenizer,query: str,history: Optional[HistoryType],system: str = "You are a helpful assistant.",append_history: bool = True,stream: Optional[bool] = _SENTINEL,stop_words_ids: Optional[List[List[int]]] = None,generation_config: Optional[GenerationConfig] = None,**kwargs,) -> Tuple[str, HistoryType]:generation_config = generation_config if generation_config is not None else self.generation_configassert stream is _SENTINEL, _ERROR_STREAM_IN_CHATassert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMATif history is None:history = []if stop_words_ids is None:stop_words_ids = []max_window_size = kwargs.get('max_window_size', None)if max_window_size is None:max_window_size = generation_config.max_window_sizeraw_text, context_tokens = make_context(tokenizer,query,history=history,system=system,max_window_size=max_window_size,chat_format=generation_config.chat_format,)stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))input_ids = torch.tensor([context_tokens]).to(self.device)outputs = self.generate(input_ids,stop_words_ids=stop_words_ids,return_dict_in_generate=False,generation_config=generation_config,**kwargs,)response = decode_tokens(outputs[0],tokenizer,raw_text_len=len(raw_text),context_length=len(context_tokens),chat_format=generation_config.chat_format,verbose=False,errors='replace')if append_history:history.append((query, response))return response, history

其主要流程如下:

True
True
True
False
True
False
True
False
Start
Define chat function with parameters
Check if stream is _SENTINEL
Check if generation_config.chat_format equals 'chatml'
Check if history is None
Assign empty list to history
Proceed with existing history
Check if stop_words_ids is None
Assign empty list to stop_words_ids
Proceed with existing stop_words_ids
Calculate max_window_size
Call make_context function
Extend stop_words_ids
Convert context_tokens to tensor
Call generate function
Call decode_tokens function
Check if append_history is True
Append query and response to history
Do not modify history
End

流式聊天封装

最后是封装成可以流式获取的函数。

其主要流程为:

  • 和chat方法类似,先做输入query的处理,组装context。
  • 计算停止词stop_words_ids。
  • 将停止词集合封装成StopWordsLogitsProcessor。
  • 将context转成input_ids作为模型输入。
  • 关键在这里,调用generate_stream方法进行流式生成。它会逐个token地生成序列,并用yield返回每个结果。
  • 在一个while循环中收集生成的token,并用decode方法转成文本。
  • 通过yield关键字返回每个解码的结果。
  • 最终形成一个生成器,可以不断获取模型生成的内容。
    def chat_stream(self,tokenizer: PreTrainedTokenizer,query: str,history: Optional[HistoryType],system: str = "You are a helpful assistant.",stop_words_ids: Optional[List[List[int]]] = None,logits_processor: Optional[LogitsProcessorList] = None,generation_config: Optional[GenerationConfig] = None,**kwargs,) -> Generator[str, Any, None]:generation_config = generation_config if generation_config is not None else self.generation_configassert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMATif history is None:history = []if stop_words_ids is None:stop_words_ids = []max_window_size = kwargs.get('max_window_size', None)if max_window_size is None:max_window_size = generation_config.max_window_sizeraw_text, context_tokens = make_context(tokenizer,query,history=history,system=system,max_window_size=max_window_size,chat_format=generation_config.chat_format,)stop_words_ids.extend(get_stop_words_ids(generation_config.chat_format, tokenizer))if stop_words_ids is not None:stop_words_logits_processor = StopWordsLogitsProcessor(stop_words_ids=stop_words_ids,eos_token_id=generation_config.eos_token_id,)if logits_processor is None:logits_processor = LogitsProcessorList([stop_words_logits_processor])else:logits_processor.append(stop_words_logits_processor)input_ids = torch.tensor([context_tokens]).to(self.device)from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfigself.__class__.generate_stream = NewGenerationMixin.generateself.__class__.sample_stream = NewGenerationMixin.sample_streamstream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)def stream_generator():outputs = []for token in self.generate_stream(input_ids,return_dict_in_generate=False,generation_config=stream_config,logits_processor=logits_processor,seed=-1,**kwargs):outputs.append(token.item())yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')return stream_generator()

小结

这节我们终于介绍完了千问7b的模型的代码。凡是讲源码的肯定会遇到大量细节,这些细节也未必是值得花太多精力去抠的,但是原汁原味的代码还是能更精确地表达功能的真实含义。
后面我们还会将模型实现抽象一下,做更系统化的讲解便于初学者理解。对于从业的同学,因为你们面对的就是这些细节,所以先熟悉起来吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/118106.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无涯教程-Android - RadioGroup函数

RadioGroup类用于单选按钮集。 如果我们选中属于某个单选按钮组的一个单选按钮,它将自动取消选中同一组中以前选中的任何单选按钮。 RadioGroup属性 以下是与RadioGroup控制相关的重要属性。您可以查看Android官方文档以获取属性的完整列表以及可以在运行时更改这些属性的相关…

day-07 I/O复用(select)

一.I/O复用 &#xff08;一&#xff09;基于I/O复用的服务器端 1.多进程服务器 每次服务都需要创建一个进程&#xff0c;需要大量的运算和内存空间 2.复用 只需创建一个进程。 3.复用技术在服务器端的应用 &#xff08;二&#xff09;select函数实现服务器端 &#xff08;…

03-基础例程3

基础例程3 01、外部中断 ESP32的外部中断有上升沿、下降沿、低电平、高电平触发模式。 实验目的 使用外部中断功能实现按键控制LED的亮灭 按键按下为0。【即下降沿】 * 接线说明&#xff1a;按键模块-->ESP32 IO* (K1-K4)-->(14,27,26,25)* * …

软件测试Day6|接口测试

学习流程 接口测试流程 需求分析和评审–接口文档分析–编写测试用例–测试用例设计及评审–测试脚本构建–执行测试用例–缺陷管理和回归–测试报告和总结计网基础&#xff08;URL、请求、响应&#xff09; 接口文档解析 拿到一个项目接口之后&#xff0c;先测试业务接口还是…

抽象轻松的C语言

#include <stdio.h> /* 预处理指令*/ /* 函数 */ int main() {int log 3.14;printf("hello word * %d\n easy", log);getchar();/* 获取键盘输入的字母&#xff0c;在这个程序中的作用是防止程序瞬间关闭 */return 0; } 上一篇说过&#xff0c;C程序是C语言的…

21.4 CSS 盒子模型

1. 边框样式 border-style属性: 指定元素的边框样式.常用属性值: - none: 无边框(默认值). - solid: 实线边框. - dotted: 点状边框. - dashed: 虚线边框. - double: 双线边框. - groove: 凹槽状边框. - ridge: 脊状边框. - inset: 内阴影边框. - outset: 外阴影边框.这些值可…

「MySQL-02」数据库的操纵、备份、还原和编码规则

目录 一、库操作 1. 创建数据库 2. 查看所有数据库 3. 删除数据库 4. 修改数据库 5. 进入一个数据库 二、查看和设置数据库的编码规则 1. MySQL的两个编码规则&#xff1a;字符集和校验规则 2. 查看MySQL当前使用的字符集以及校验规则 3. 查看MySQL支持的所有字符集 4. 查看MyS…

肖sir__linux详解__002(系统命令)

linux系统命令 1、df 查看磁盘使用情况 &#xff08;1&#xff09;df 查看磁盘使用情况&#xff08;按kb单位显示&#xff09; &#xff08;2&#xff09;df -h 按单位显示磁盘使用情况 2、top 实时查看动态进程 &#xff08;1&#xff09;top 详解&#xff1a; 第一行&…

为什么要学习C++

操作系统历史 UINX操作系统诞生之初是用汇编语言编写的。随着UNIX的发展&#xff0c;汇编语言的开发效率成为一个瓶颈。寻找新的高效开发语言成为UNIX开发者需要解决的问题。当时BCPL语言成为了当时的选择之一。Ken Thomposn对BCPL进行简化得到了B语言。但是B语言不是直接生成…

【AWS实验】 配置中转网关及对等连接

文章目录 实验概览目标实验环境任务 1&#xff1a;查看网络拓扑并创建基准任务 2&#xff1a;创建中转网关任务 3&#xff1a;创建中转网关挂载任务 4&#xff1a;创建中转网关路由表任务 4.1&#xff1a;创建路由表关联任务 4.2&#xff1a;创建路由传播 任务 5&#xff1a;更…

Android JNI系列详解之ndk-build工具的使用

一、Android项目中使用ndk-build工具编译库文件 之前介绍过CMake编译工具的使用&#xff0c;今天介绍一种ndk自带的编译工具ndk-build的使用。 ndk-build目前主要有两种配置使用方式&#xff1a; 如上图所示&#xff0c;第一种方式是Android.mkApplication.mkgradle的方式生成…

SpringBoot初级开发--服务请求(GET/POST)所有参数的记录管理(8)

服务端在定位错误的时候&#xff0c;有时候要还原现场&#xff0c;这就要把当时的所有入参参数都能记录下来&#xff0c;GET还好说&#xff0c;基本NGINX都会记录。但是POST的请求参数基本不会被记录&#xff0c;这就需要我们通过一些小技巧来记录这些参数&#xff0c;放入日志…

C++ struct 笔记(超级详细)

今日碎碎念&#xff1a;我在学C语言时经常用到结构体struct&#xff0c;之后在写C程序时遇到在struct中定义构造函数和成员函数的情况&#xff0c;这在c语言中是从未遇到过的&#xff0c;觉得奇怪&#xff0c;想到之前并没有真正系统学习C里的struct&#xff0c;有必要今天详细…

企业架构LNMP学习笔记8

1、 运维人员需要考虑安全性、稳定性。 安装&#xff1a; 解压进入到目录&#xff1a; shell > tar zxf php-7.2.12.tar.gz shell > cd php-7.2.12 安装依赖软件&#xff1a; yum -y install libxml2-devel libjpeg-devel libpng-devel freetype-devel curl-devel op…

uniapp 微信小程序 获取用户头像和昵称

一、背景 自2022年10月25日后&#xff0c;小程序 wx.getUserProfile 接口 被收回&#xff0c;通过 wx.getUserInfo 接口获取用户头像将统一返回默认灰色头像&#xff0c;昵称将统一返回 “微信用户”。如需获取用户头像昵称&#xff0c;可以手动获取&#xff0c;具体步骤&…

Java单元测试及常用语句 | 京东物流技术团队

1 前言 编写Java单元测试用例&#xff0c;即把一段复杂的代码拆解成一系列简单的单元测试用例&#xff0c;并且无需启动服务&#xff0c;在短时间内测试代码中的处理逻辑。写好Java单元测试用例&#xff0c;其实就是把“复杂问题简单化&#xff0c;建单问题深入化“。在编写的…

Shell脚本练习——系统应用相关

显示系统信息 [rootwenzi data]#cat systemInfo.sh #/bin/bash RED"\E[1;31m" GREEN"\E[1;32m" END"\E[0m" echo -e "$GREEN----------------------Host systeminfo--------------------$END" echo -e "HOSTNAME: $REDho…

没有 JavaScript 计时器的自动播放轮播 - CSS 动画

先看效果&#xff1a; 再看代码&#xff08;查看更多&#xff09;&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>计时器</title><style>* {padding: 0;margin: 0;box-siz…

WorkManager的基本使用

目录 一、WorkManager概述1. WorkManager的作用&#xff1a;2. WorkManager的各个角色 二、依赖库的导入三、WorkManager几种基本使用1. 单一任务的执行2. 数据 互相传递3. 多个任务 顺序执行4. 重复执行后台任务5. 约束条件6. 证明 app被杀掉之后&#xff0c;还在后台执行 四、…