Contents
- 1. num_heads
- 2. Working through it in PyTorch
1. num_heads
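Split the last dimension of the matrix (model_dim) according to num_heads. Each feature vector of length model_dim is cut into num_heads consecutive chunks of length head_dim = model_dim // num_heads, so model_dim must be divisible by num_heads.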
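The split described above can be sketched in plain Python for a single feature vector, with no PyTorch needed. `split_heads` is a hypothetical helper written for illustration and is not a PyTorch function. It assumes model_dim is divisible by num_heads and gives head h the consecutive slice that starts at h * head_dim.

```python
# Minimal sketch (plain Python, no PyTorch): split one feature vector of
# length model_dim into num_heads consecutive chunks of length head_dim.
# split_heads is a hypothetical helper used only for illustration.

def split_heads(vector, num_heads):
    model_dim = len(vector)
    if model_dim % num_heads != 0:
        raise ValueError("model_dim must be divisible by num_heads")
    head_dim = model_dim // num_heads
    # Head h owns the slice [h * head_dim, (h + 1) * head_dim)
    return [vector[h * head_dim:(h + 1) * head_dim] for h in range(num_heads)]


if __name__ == "__main__":
    # First row of mat1 in section 2, with num_heads = 3
    print(split_heads([0, 1, 2, 3, 4, 5], 3))  # [[0, 1], [2, 3], [4, 5]]
```

Each chunk here corresponds to one row of the innermost blocks of mat2 in section 2.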
2. Working through it in PyTorch
- PyTorch code
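import torch

torch.set_printoptions(precision=3, sci_mode=False)

if __name__ == "__main__":
    batch_size = 2
    seq_len = 4
    model_dim = 6
    num_heads = 3
    head_dim = model_dim // num_heads  # 6 // 3 = 2 features per head

    # mat1: (batch_size, seq_len, model_dim) = (2, 4, 6), filled with 0..47
    mat_total = batch_size * seq_len * model_dim
    mat1 = torch.arange(mat_total).reshape((batch_size, seq_len, model_dim))
    print(f"mat1=\n{mat1}")

    # mat2: split the last dimension into heads -> (batch_size, seq_len, num_heads, head_dim)
    mat2 = mat1.reshape((batch_size, seq_len, num_heads, head_dim))
    print(f"mat2=\n{mat2}")

    # mat3: swap the seq_len and num_heads axes -> (batch_size, num_heads, seq_len, head_dim)
    mat3 = mat2.transpose(1, 2)
    print(f"mat3=\n{mat3}")

    # mat4: merge the batch and head axes -> (batch_size * num_heads, seq_len, head_dim);
    # reshape (unlike view) also works on the non-contiguous result of transpose
    mat4 = mat3.reshape((batch_size * num_heads, seq_len, head_dim))
    print(f"mat1.shape=\n{mat1.shape}")
    print(f"mat1=\n{mat1}")
    print(f"mat4.shape=\n{mat4.shape}")
    print(f"mat4=\n{mat4}")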
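To see exactly where each element ends up, the same reshape, transpose, reshape chain can be mimicked with nested Python lists instead of tensors. This is an illustrative sketch, not how PyTorch implements it. `split_heads_3d` is a hypothetical name, and the input is built to match `torch.arange(48).reshape(2, 4, 6)`.

```python
# Pure-Python sketch of mat1 -> mat2 -> mat3 -> mat4 from the PyTorch code,
# using nested lists instead of tensors (illustration only).

def split_heads_3d(x, num_heads):
    """x: nested list of shape (batch_size, seq_len, model_dim).
    Returns a nested list of shape (batch_size * num_heads, seq_len, head_dim)."""
    batch_size, seq_len, model_dim = len(x), len(x[0]), len(x[0][0])
    head_dim = model_dim // num_heads
    out = []
    for b in range(batch_size):
        for h in range(num_heads):  # output index is b * num_heads + h
            out.append([
                x[b][s][h * head_dim:(h + 1) * head_dim]  # head h's slice of token s
                for s in range(seq_len)
            ])
    return out


if __name__ == "__main__":
    batch_size, seq_len, model_dim, num_heads = 2, 4, 6, 3
    # Same contents as torch.arange(48).reshape(2, 4, 6)
    mat1 = [[[(b * seq_len + s) * model_dim + d for d in range(model_dim)]
             for s in range(seq_len)] for b in range(batch_size)]
    mat4 = split_heads_3d(mat1, num_heads)
    print(len(mat4), len(mat4[0]), len(mat4[0][0]))  # 6 4 2
    print(mat4[1])  # [[2, 3], [8, 9], [14, 15], [20, 21]]
```

The loop order (batch outer, head inner) is what makes the batch and head axes merge in the same order as `reshape` on the (batch_size, num_heads, seq_len, head_dim) tensor.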
- Output:
mat1=
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])
mat2=
tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],

         [[ 6,  7],
          [ 8,  9],
          [10, 11]],

         [[12, 13],
          [14, 15],
          [16, 17]],

         [[18, 19],
          [20, 21],
          [22, 23]]],


        [[[24, 25],
          [26, 27],
          [28, 29]],

         [[30, 31],
          [32, 33],
          [34, 35]],

         [[36, 37],
          [38, 39],
          [40, 41]],

         [[42, 43],
          [44, 45],
          [46, 47]]]])
mat3=
tensor([[[[ 0,  1],
          [ 6,  7],
          [12, 13],
          [18, 19]],

         [[ 2,  3],
          [ 8,  9],
          [14, 15],
          [20, 21]],

         [[ 4,  5],
          [10, 11],
          [16, 17],
          [22, 23]]],


        [[[24, 25],
          [30, 31],
          [36, 37],
          [42, 43]],

         [[26, 27],
          [32, 33],
          [38, 39],
          [44, 45]],

         [[28, 29],
          [34, 35],
          [40, 41],
          [46, 47]]]])
mat1.shape=
torch.Size([2, 4, 6])
mat1=
tensor([[[ 0,  1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10, 11],
         [12, 13, 14, 15, 16, 17],
         [18, 19, 20, 21, 22, 23]],

        [[24, 25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34, 35],
         [36, 37, 38, 39, 40, 41],
         [42, 43, 44, 45, 46, 47]]])
mat4.shape=
torch.Size([6, 4, 2])
mat4=
tensor([[[ 0,  1],
         [ 6,  7],
         [12, 13],
         [18, 19]],

        [[ 2,  3],
         [ 8,  9],
         [14, 15],
         [20, 21]],

        [[ 4,  5],
         [10, 11],
         [16, 17],
         [22, 23]],

        [[24, 25],
         [30, 31],
         [36, 37],
         [42, 43]],

        [[26, 27],
         [32, 33],
         [38, 39],
         [44, 45]],

        [[28, 29],
         [34, 35],
         [40, 41],
         [46, 47]]])
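Reading the printed tensors element by element, the whole reshape, transpose, reshape chain reduces to a single index mapping. Here H and D are shorthand introduced for num_heads and head_dim. The indices b, s, h, d run over batch, sequence position, head, and feature within a head:

$$
\mathrm{mat4}[\,bH + h,\; s,\; d\,] = \mathrm{mat3}[\,b,\; h,\; s,\; d\,] = \mathrm{mat2}[\,b,\; s,\; h,\; d\,] = \mathrm{mat1}[\,b,\; s,\; hD + d\,]
$$

For example, with H = 3 and D = 2, take b = 0, h = 1, s = 2, d = 1. This gives mat4[1, 2, 1] = mat1[0, 2, 1·2 + 1] = mat1[0, 2, 3] = 15, which matches the value 15 in the second 4×2 block of mat4 printed above.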
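- Thought: this is the same idea as partitioning a matrix into blocks. Suppose y = Ax is hard to reason about as a whole. Split the columns of A into blocks A1, A2, A3 and split x into matching pieces x1, x2, x3. Then y = (A1, A2, A3)x = A1x1 + A2x2 + A3x3, so each block's contribution can be examined on its own, much as each head works on its own head_dim slice.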
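The block decomposition in the thought above can be checked numerically. The sketch below uses plain Python lists. `matvec` and `blockwise_matvec` are hypothetical helpers, and A and x are arbitrary example values chosen only for illustration.

```python
# Sketch: check that y = A x equals A1 x1 + A2 x2 + A3 x3 when A is split
# column-wise into equal blocks and x is split into matching slices.
# A and x below are arbitrary example values.

def matvec(A, x):
    # Ordinary matrix-vector product: y_i = sum_j A[i][j] * x[j]
    return [sum(a * v for a, v in zip(row, x)) for row in A]


def blockwise_matvec(A, x, num_blocks):
    # Split the columns of A (and the entries of x) into num_blocks equal parts,
    # multiply each block separately, then add the partial results.
    n_cols = len(x)
    if n_cols % num_blocks != 0:
        raise ValueError("number of columns must be divisible by num_blocks")
    width = n_cols // num_blocks
    y = [0] * len(A)
    for k in range(num_blocks):
        cols = range(k * width, (k + 1) * width)
        A_k = [[row[j] for j in cols] for row in A]  # column block A_k
        x_k = [x[j] for j in cols]                   # matching slice x_k
        y = [acc + part for acc, part in zip(y, matvec(A_k, x_k))]
    return y


if __name__ == "__main__":
    A = [[1, 2, 3, 4, 5, 6],
         [0, 1, 0, 1, 0, 1]]
    x = [1, 1, 2, 2, 3, 3]
    print(matvec(A, x))               # [50, 6]
    print(blockwise_matvec(A, x, 3))  # [50, 6]
```

With num_blocks = num_heads, each slice of x has length head_dim, mirroring the split of the last dimension in section 1.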