Partial code:
# CNN-Transformer
import torch.nn as nn


class CNNTransformerEncoder(nn.Module):
    def __init__(self, input_features, transformer_encoder_heads, embedding_features,
                 cnn_kernel_size, dim_feedforward_enc, n_encoder_layer):
        super(CNNTransformerEncoder, self).__init__()
        # Conv1d input:  [batch_size, input_features, input_seq_len]
        # Conv1d output: [batch_size, embedding_features, output_len]
        #   (output_len depends on kernel_size, padding and stride)
        self.cnn_embedding = nn.Conv1d(input_features, embedding_features, cnn_kernel_size)  # CNN part
        # PositionalEncoder is defined elsewhere in the project (see the full-code link below)
        self.position_embedding = PositionalEncoder(d_model=embedding_features)  # positional encoding
        transformer_encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_features,
            nhead=transformer_encoder_heads,
            dim_feedforward=dim_feedforward_enc,
            activation='gelu')  # Transformer encoder layer
        self.transformer_encoder = nn.TransformerEncoder(transformer_encoder_layer,
                                                         num_layers=n_encoder_layer)

    def forward(self, input_seq):
        cnn_embedding_results = self.cnn_embedding(input_seq)  # pass the input through the CNN
        # permute to [batch_size, seq_len, embedding_features] and add positional encoding
        embedding_with_position = self.position_embedding(cnn_embedding_results.permute((0, 2, 1)))
        # permute to [seq_len, batch_size, embedding_features] for the Transformer encoder
        encoder_res = self.transformer_encoder(embedding_with_position.permute((1, 0, 2)))
        return encoder_res
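The partial code references a PositionalEncoder module that is not shown here. As a reference point only, a minimal sketch of a standard sinusoidal positional encoder (the usual Transformer formulation) could look like the following; the project's actual implementation is in the full code and may differ.

import math
import torch
import torch.nn as nn


class PositionalEncoder(nn.Module):
    """Minimal sinusoidal positional-encoding sketch (assumption;
    the project's real PositionalEncoder may differ)."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)                        # [max_len, d_model]
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)               # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)               # odd dimensions
        self.register_buffer('pe', pe.unsqueeze(0))                # [1, max_len, d_model]

    def forward(self, x):
        # x: [batch_size, seq_len, d_model]; add the (truncated) positional table
        return x + self.pe[:, :x.size(1), :]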
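A quick usage sketch showing how the tensor shapes flow through the encoder; the hyperparameter values below are placeholders for illustration, not the ones used in the project.

import torch

# placeholder hyperparameters, chosen only so that the shapes work out
encoder = CNNTransformerEncoder(input_features=8,
                                transformer_encoder_heads=4,
                                embedding_features=64,
                                cnn_kernel_size=3,
                                dim_feedforward_enc=128,
                                n_encoder_layer=2)

x = torch.randn(16, 8, 96)      # [batch_size, input_features, input_seq_len]
out = encoder(x)
# Conv1d with kernel_size=3 and no padding shortens the sequence: 96 -> 94,
# and the Transformer encoder keeps the [seq_len, batch_size, features] layout.
print(out.shape)                # torch.Size([94, 16, 64])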
Project screenshot:
Data:
Test set prediction comparison:
# Full code
https://mbd.pub/o/works/592982