1. Transformer结构图
2. python
import torch
import torch. nn as nn
import torch.nn.functional as F

# Print tensors with three decimals and no scientific notation so the
# demo output stays compact and easy to compare by eye.
torch.set_printoptions(precision=3, sci_mode=False)


def cross_entropy_reductions(logits, label):
    """Compute per-token and mean cross-entropy for sequence logits.

    Args:
        logits: Float tensor whose last dimension holds the class scores
            and whose second-to-last dimension is the sequence position
            (the demo below uses shape ``(batch, seq, vocab)``).
        label: Integer tensor of target class indices, one per position
            (the demo below uses shape ``(batch, seq)``).

    Returns:
        A tuple ``(logits_t, result_none, result_mean)`` where
        ``logits_t`` is ``logits`` with its last two dimensions swapped,
        ``result_none`` is the unreduced loss (one value per label) and
        ``result_mean`` is the loss under the default reduction.
    """
    # F.cross_entropy wants the class dimension at index 1 for inputs with
    # extra dimensions, so swap the trailing (seq, vocab) axes.
    logits_t = logits.transpose(-1, -2)
    result_none = F.cross_entropy(logits_t, label, reduction="none")
    result_mean = F.cross_entropy(logits_t, label)
    return logits_t, result_none, result_mean


def main():
    """Show that averaging the unreduced loss matches the default reduction."""
    batch_size = 2
    seq_length = 3
    vocab_size = 4

    # Draw logits before labels: keeps the random-number consumption order
    # of the original script, so seeded runs produce the same tensors.
    logits = torch.randn(batch_size, seq_length, vocab_size)
    label = torch.randint(0, vocab_size, (batch_size, seq_length))

    logits_t, result_none, result_mean = cross_entropy_reductions(logits, label)
    # Manual mean over every per-token loss; should equal result_mean.
    result_none_mean = torch.mean(result_none)

    print(f"logits=\n {logits} ")
    print(f"logits_t=\n {logits_t} ")
    print(f"label=\n {label} ")
    print(f"result_none=\n {result_none} ")
    print(f"result_mean=\n {result_mean} ")
    print(f"result_none_mean= {result_none_mean} ")


if __name__ == "__main__":
    main()
logits=
tensor( [ [ [ 0.477 , 2.017 , 1.016 , - 0.299 ] , [ - 0.189 , 0.321 , - 0.885 , 1.418 ] , [ 0.027 , - 0.606 , 0.079 , - 0.491 ] ] , [ [ 1.911 , 1.643 , - 0.327 , 0.185 ] , [ - 0.031 , - 1.463 , - 0.073 , 1.391 ] , [ - 0.710 , 0.811 , 1.521 , 0.033 ] ] ] )
logits_t=
tensor( [ [ [ 0.477 , - 0.189 , 0.027 ] , [ 2.017 , 0.321 , - 0.606 ] , [ 1.016 , - 0.885 , 0.079 ] , [ - 0.299 , 1.418 , - 0.491 ] ] , [ [ 1.911 , - 0.031 , - 0.710 ] , [ 1.643 , - 1.463 , 0.811 ] , [ - 0.327 , - 0.073 , 1.521 ] , [ 0.185 , 1.391 , 0.033 ] ] ] )
label=
tensor( [ [ 0 , 0 , 0 ] , [ 3 , 0 , 0 ] ] )
result_none=
tensor( [ [ 2.059 , 2.098 , 1.157 ] , [ 2.444 , 1.848 , 2.832 ] ] )
result_mean=
2.0730881690979004
result_none_mean= 2.0730881690979004