3. train_encoder_decoder.py

train_encoder_decoder.py

#__future__ 模块提供了一种方式,允许开发者在当前版本的 Python 中使用即将在将来版本中成为标准的功能和语法特性。此处为了确保代码同时兼容Python 2和Python 3版本中的print函数
from __future__ import print_function # 导入标准库和第三方库#os 是一个标准库模块,全称为 "operating system",用于提供与操作系统交互的功能。导入os.path模块用于处理文件和目录路径
import os.path 
#从os模块中导入了path子模块,可以直接使用path来调用os.path中的函数(上面的代码可以不用写)
from os import path #导入了sys模块,用于系统相关的参数和函数
import sys 
#导入了math模块,提供了数学运算函数
import math 
#导入了NumPy库,并使用np作为别名,NumPy是用于科学计算的基础库
import numpy as np 
#导入了Pandas库,并使用pd作为别名,Pandas是用于数据分析的强大库
import pandas as pd # 导入深度学习相关库
# keras 是一个机器学习和深度学习的库。backend 模块提供了对底层深度学习框架(如TensorFlow、Theano等)的访问接口,使得在不同的后端之间进行无缝切换变得更加容易。
import tensorflow as tf # 导入了Keras的backend模块,并使用K作为别名,用于访问后端引擎的函数
from keras import backend as K 
# Model 类在 Keras 中允许用户以函数式 API 的方式构建更为复杂的神经网络模型。通过使用 Model 类,可以自由地定义输入层、输出层和中间层,并将它们连接起来形成一个完整的模型。
from keras.models import Model # 1. LSTM (Long Short-Term Memory) 和 GRU (Gated Recurrent Unit) 都是循环神经网络 (RNN) 的变体,可以用来学习长期依赖关系,用于处理序列数据。
# 2. 在处理序列数据时,经常需要将某个层(如 Dense 层)应用于序列中的每一个时间步。TimeDistributed 可以将这样的层包装起来,使其能够处理整个序列。
# 3. 在函数式 API 中,可以使用 Input 来定义模型的输入节点,指定输入的形状和数据类型。
# 4. 在神经网络中,Dense 层是最基本的层之一,每个输入节点都与输出节点相连,用于学习数据中的非线性关系。
# 5. RepeatVector接受一个 2D 张量作为输入,并重复其内容 n 次生成一个3D张量,用于序列数据处理中的某些操作,例如将上下文向量重复多次以与每个时间步相关联。
from keras.layers import LSTM, GRU, TimeDistributed, Input, Dense, RepeatVector # 1. CSVLogger 是一个回调函数,用于将每个训练周期的性能指标(如损失和指标值)记录到 CSV 文件中。训练完成后,可以使用记录的数据进行分析和可视化,帮助了解模型在训练过程中的表现
# 2. EarlyStopping 是一个回调函数,用于在训练过程中根据验证集的表现来提前终止训练。它监控指定的性能指标(如验证损失)并在连续若干个周期内没有改善时停止训练,防止模型过拟合。
# 3. TerminateOnNaN 是一个回调函数,用于在训练过程中检测到损失函数返回 NaN(Not a Number)时提前终止训练。这可以帮助捕捉和处理训练过程中出现的数值问题,避免模型继续训练无效参数
from keras.callbacks import CSVLogger, EarlyStopping, TerminateOnNaN# regularizers 用于定义正则化项,减少模型的过拟合,通过向模型的损失函数添加惩罚项来限制模型参数的大小或者复杂度。
from keras import regularizers # Adam (Adaptive Moment Estimation) 优化器是基于随机梯度下降 (Stochastic Gradient Descent, SGD) 的方法之一,但它结合了动量优化和自适应学习率的特性: 
# 1. 动量(Momentum):类似于经典的随机梯度下降中的动量项,Adam会在更新参数时考虑上一步梯度的指数加权平均值,以减少梯度更新的方差,从而加速收敛; 
# 2. 自适应学习率:Adam根据每个参数的梯度的一阶矩估计(均值)和二阶矩估计(方差)来自动调整学习率。这种自适应学习率的机制可以使得不同参数有不同的学习率,从而更有效地优化模型。
from keras.optimizers import Adam # 1. 假设有一个函数 func(a, b, c),通过 partial(func, 1) 可以创建一个新函数,相当于 func(1, b, c),其中 1 是已经固定的参数。
# 2. update_wrapper 是一个函数,用于更新后一个函数的元信息(比如文档字符串、函数名等)到前一个函数上
from functools import partial, update_wrapper 
def wrapped_partial(func, *args, **kwargs):partial_func = partial(func, *args, **kwargs)update_wrapper(partial_func, func)return partial_func# 这是一个自定义的损失函数,计算加权的均方误差(Mean Squared Error)
# y_true是真实值,y_pred是预测值,weights是权重
# axis=-1指定了在计算均值时应该沿着最内层的轴进行操作,即在每个样本或数据点上进行平均,而不是在整个批次或特征维度上进行平均
def weighted_mse(y_true, y_pred, weights):return K.mean(K.square(y_true - y_pred) * weights, axis=-1)# 这部分代码用于选择使用的GPU设备。它从命令行参数中获取一个整数值gpu,如果gpu小于3,则设置CUDA环境变量以指定使用的GPU设备
import os
gpu = int(sys.argv[-13])
if gpu < 3:os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152os.environ["CUDA_VISIBLE_DEVICES"]= "{}".format(gpu)from tensorflow.python.client import device_libprint(device_lib.list_local_devices())# 这部分代码获取了一系列命令行参数,并将它们分别赋值给变量 
# 这些参数包括dataname数据集名称、nb_batches训练的批次数量、nb_epochs训练周期数、lr学习率、penalty正则化惩罚、dr丢弃率、patience耐心(用于Early Stopping),n_hidden神经网络中隐藏层的数量,hidden_activation隐藏层激活函数
imp = sys.argv[-1]
T = sys.argv[-2]
t0 = sys.argv[-3]
dataname = sys.argv[-4] 
nb_batches = sys.argv[-5]
nb_epochs = sys.argv[-6]
lr = float(sys.argv[-7])
penalty = float(sys.argv[-8])
dr = float(sys.argv[-9])
patience = sys.argv[-10]
n_hidden = int(sys.argv[-11])
hidden_activation = sys.argv[-12]# results_directory 是一个字符串,表示将要创建的结果文件夹路径,dataname 是之前从命令行参数中获取的数据集名称。
# .format(dataname) 是字符串的格式化方法,它会将 dataname 变量的值插入到占位符 {} 的位置。
# 如果这个文件夹路径不存在,就使用 os.makedirs 函数创建它。这个路径通常用于存储训练模型的结果或者日志。
results_directory = 'results/encoder-decoder/{}'.format(dataname)
if not os.path.exists(results_directory):os.makedirs(results_directory)# 定义了一个函数 create_model,用于创建、编译和返回一个循环神经网络(RNN)模型
def create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation):""" creates, compiles and returns a RNN model @param nb_features: the number of features in the model"""# 这里定义了两个输入层:# 1. inputs 是一个形状为 (n_pre, nb_features) 的输入张量,用于模型的主输入;# 2. weights_tensor 是一个形状相同的张量,用于传递权重或其他需要的信息inputs = Input(shape=(n_pre, nb_features), name="Inputs")  weights_tensor = Input(shape=(n_pre, nb_features), name="Weights") # 编码器,这里使用了两个 LSTM 层: # lstm_1 的主要作用是将输入序列转换为一个语义上丰富的固定长度表示(即隐藏状态),并且该表示包含了输入序列的全部信息。这个固定长度的表示将作为解码器的输入,用于生成目标序列。# 1. n_hidden:指定 LSTM 层的隐藏单元数,决定了网络的记忆容量和复杂度。# 2. dropout=dr 和 recurrent_dropout=dr:分别指定了输入和循环 dropout 的比例,有助于防止过拟合。# 3. activation=hidden_activation:设置了 LSTM 单元的激活函数,这里是通过 hidden_activation 参数传递的。# 4. return_sequences=True:指定返回完整的输出序列,而不是只返回最后一个时间步的输出。这是为了将完整的输入序列信息编码成隐藏状态序列,以便后续的解码器使用。# lstm_2 是一个相同的 LSTM 层,但它只返回最后一个时间步的输出 lstm_1 = LSTM(n_hidden, dropout=dr, recurrent_dropout=dr, activation=hidden_activation, return_sequences=True, name='LSTM_1')(inputs) lstm_2 = LSTM(n_hidden, activation=hidden_activation, return_sequences=False, name='LSTM_2')(lstm_1) repeat = RepeatVector(n_post, name='Repeat')(lstm_2) # get the last output of the LSTM and repeats itgru_1 = GRU(n_hidden, activation=hidden_activation, return_sequences=True, name='Decoder')(repeat)  # Decoderoutput= TimeDistributed(Dense(output_dim, activation='linear', kernel_regularizer=regularizers.l2(penalty), name='Dense'), name='Outputs')(gru_1)model = Model([inputs, weights_tensor], output)# model.compile(optimizer=Adam(lr=lr), loss=cl) 对模型进行编译。# optimizer=Adam(lr=lr) 指定了优化器为 Adam,并设置了学习率为 lr。# loss=cl 指定了损失函数为 cl,即上面定义的加权均方误差函数。cl = wrapped_partial(weighted_mse, weights=weights_tensor)model.compile(optimizer=Adam(lr=lr), loss=cl)print(model.summary()) return modeldef train_model(model, dataX, dataY, weights, nb_epoches, nb_batches):# Prepare model checkpoints and callbacksstopping = EarlyStopping(monitor='val_loss', patience=int(patience), min_delta=0, verbose=1, mode='min', restore_best_weights=True)csv_logger = CSVLogger('results/encoder-decoder/{}/training_log_{}_{}_{}_{}_{}_{}_{}_{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), separator=',', append=False)terminate = TerminateOnNaN()# 训练过程中会生成一个 history 对象,其中包含了训练过程中的损失和指标等信息,但并没有直接输出最终的参数值history = model.fit(x=[dataX,weights], y=dataY, batch_size=nb_batches, verbose=1,epochs=nb_epoches, callbacks=[stopping,csv_logger,terminate],validation_split=0.2)def test_model():n_post = int(1)n_pre =int(t0)-1seq_len = int(T)wx = np.array(pd.read_csv("data/{}-wx-{}.csv".format(dataname,imp)))print('raw wx shape', wx.shape)  wXC = []for i in range(seq_len-n_pre-n_post):wXC.append(wx[i:i+n_pre]) wXC = np.array(wXC)print('wXC shape:', wXC.shape)x = np.array(pd.read_csv("data/{}-x-{}.csv".format(dataname,imp)))print('raw x shape', x.shape) dXC, dYC = [], []for i in range(seq_len-n_pre-n_post):dXC.append(x[i:i+n_pre])dYC.append(x[i+n_pre:i+n_pre+n_post])dataXC = np.array(dXC)dataYC = np.array(dYC)print('dataXC shape:', dataXC.shape)print('dataYC shape:', dataYC.shape)nb_features = dataXC.shape[2]output_dim = dataYC.shape[2]# create and fit the encoder-decoder networkprint('creating model...')model = create_model(n_pre, n_post, nb_features, output_dim, lr, penalty, dr, n_hidden, hidden_activation)train_model(model, dataXC, dataYC, wXC, int(nb_epochs), int(nb_batches))# now testprint('Generate predictions on full training set')preds_train = model.predict([dataXC,wXC], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_train.shape)preds_train = np.squeeze(preds_train)print('predictions shape (squeezed)=', preds_train.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-train-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_train, delimiter=",")print('Generate predictions on test set')wy = np.array(pd.read_csv("data/{}-wy-{}.csv".format(dataname,imp)))print('raw wy shape', wy.shape)  wY = []for i in range(seq_len-n_pre-n_post):wY.append(wy[i:i+n_pre]) # weights for outputswXT = np.array(wY)print('wXT shape:', wXT.shape)y = np.array(pd.read_csv("data/{}-y-{}.csv".format(dataname,imp)))print('raw y shape', y.shape)  dXT = []for i in range(seq_len-n_pre-n_post):dXT.append(y[i:i+n_pre]) # treated is inputdataXT = np.array(dXT)print('dataXT shape:', dataXT.shape)preds_test = model.predict([dataXT, wXT], batch_size=int(nb_batches), verbose=1)print('predictions shape =', preds_test.shape)preds_test = np.squeeze(preds_test)print('predictions shape (squeezed)=', preds_test.shape)print('Saving to results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv'.format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches))np.savetxt("results/encoder-decoder/{}/encoder-decoder-{}-test-{}-{}-{}-{}-{}-{}.csv".format(dataname,dataname,imp,hidden_activation,n_hidden,patience,dr,penalty,nb_batches), preds_test, delimiter=",")def main():test_model()return 1if __name__ == "__main__":main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/370196.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

dotnet ef工具使用

设置工具安装目录 dotnet tool install dotnetsay --tool-path G:\dotnet-tools安装 dotnet tool install --global dotnet-ef更新 dotnet tool update --global dotnet-ef查看版本 dotnet ef --version创建迁移文件 # 只有一个dbcontext dotnet ef migrations add init #…

微机原理与单片机 知识体系梳理

单片机笔记分享 我个人感觉单片机要记的东西很多&#xff0c;也很琐碎&#xff0c;特别是一些位、寄存器以及相关作用等&#xff0c;非常难以记忆。因此复习时将知识点整理在了一起做成思维导图&#xff0c;希望对大家有所帮助。内容不是很多&#xff0c;可能有些没覆盖全&…

PySide6 实现资源的加载:深入解析与实战案例

目录 1. 引言 2. 加载内置资源 3. 使用自定义资源文件&#xff08;.qrc&#xff09; 创建.qrc文件 编译.qrc文件 加载资源 4. 动态加载UI文件 使用Qt Designer设计UI 加载UI文件 5. 注意事项与最佳实践 6. 结论 在开发基于PySide6的桌面应用程序时&…

Qt 基础组件速学 事件过滤器

学习目标&#xff1a;理解事件过滤器 前置环境 运行环境:qt creator 4.12 学习内容和效果演示&#xff1a; Qt 提供了事件过滤器的机制,允许我们在事件到达目标对象之前对事件进行拦截和处理。这在以下情况下非常有用: 全局事件处理: 我们可以在应用程序级别安装一个事件过…

JVM(13):虚拟机性能分析和故障解决工具之Visual VM

1 Visual VM作用 是到目前为止随JDK发布的功能最强大的运行监视和故障处理程序&#xff0c;并且可以遇见在未来一段时间内都是官方主力发展的虚拟机故障处理工具。官方在VisualVM的软件说明中写上了“All-in-One”的描述字样&#xff0c;预示着他除了运行监视、故障处理外&…

Android在framework层添加自定义服务的流程

环境说明 ubuntu16.04android4.1java version “1.6.0_45”GNU Make 3.81gcc version 5.4.0 20160609 (Ubuntu 5.4.0-6ubuntu1~16.04.12) 可能有人会问&#xff0c;现在都2024了怎么还在用android4版本&#xff0c;早都过时了。确实&#xff0c;现在最新的都是Android13、And…

基于YOLOv5的人脸目标检测

本文是在之前的基于yolov5的人脸关键点检测项目上扩展来的。因为人脸目标检测的效果将直接影响到人脸关键点检测的效果&#xff0c;因此本文主要讲解利用yolov5训练人脸目标检测(关键点检测可以看我人脸关键点检测文章) 基于yolov5的人脸关键点检测&#xff1a;人脸关键点检测…

复现YOLO_ORB_SLAM3_with_pointcloud_map项目记录

文章目录 1.环境问题2.遇到的问题2.1编译问题1 monotonic_clock2.2 associate.py2.3 associate.py问题 3.运行问题 1.环境问题 首先环境大家就按照github上的指定环境安装即可 环境怎么安装网上大把的资源&#xff0c;自己去找。 2.遇到的问题 2.1编译问题1 monotonic_cloc…

Android增量更新----java版

一、背景 开发过程中&#xff0c;随着apk包越来越大&#xff0c;全量更新会使得耗时&#xff0c;同时浪费流量&#xff0c;为了节省时间&#xff0c;使用增量更新解决。网上很多文章都不是很清楚&#xff0c;没有手把手教学&#xff0c;使得很多初学者&#xff0c;摸不着头脑&a…

jmeter测试工具学习

1.双击jar包打开&#xff0c;发现那个bat打不开 2.新建plan之后编辑添加线程组 会加入500*5次请求 3.添加HTTP请求 添加字段 为了让http请求发送到不同的分片&#xff0c;要把userid随机化 4.添加监听器 5.聚合报告

Wish卖家必读:如何安全有效地进行店铺测评

Wish以其独特的商业模式和先进的技术在电商领域独树一帜。作为北美和欧洲最大的移动电商平台之一&#xff0c;Wish拥有庞大的用户基础&#xff0c;其中90%的卖家来自中国&#xff0c;这不仅显示了其在全球电商市场中的影响力&#xff0c;也反映了其对中国卖家的吸引力。 Wish平…

vxe-table合并行数据;element-plus的el-table动态合并行

文章目录 一、vxe-table合并行数据1.代码 二、使用element-plus的el-table动态合并行2.代码 注意&#xff1a;const fields 是要合并的字段 一、vxe-table合并行数据 1.代码 <vxe-tableborderresizableheight"500":scroll-y"{enabled: false}":span-m…

Ubuntu 22.04远程自动登录桌面环境

如果需要远程自动登录桌面环境&#xff0c;首先需要将Ubuntu的自动登录打开&#xff0c;在【settings】-【user】下面 然后要设置【Sharing】进行桌面共享&#xff0c;Ubuntu有自带的桌面共享功能&#xff0c;不需要另外去安装xrdp或者vnc之类的工具了 点开【Remote Desktop】…

window系统openssl开发环境搭建(VS2017)

window系统openssl开发环境搭建 VS2017 一、下载openssl二、安装openssl三、openssl项目配置3.1 配置include文件3.2 配置openssl动态库四、编写openssl测试代码五、问题总结5.1 问题 一5.2 问题二一、下载openssl https://slproweb.com/products/Win32OpenSSL.html 根据自己…

CTF实战:从入门到提升

CTF实战&#xff1a;从入门到提升 &#x1f680;前言 没有网络安全就没有国家安全&#xff0c;网络安全不仅关系到国家整体信息安全&#xff0c;也关系到民生安全。近年来&#xff0c;随着全国各行各业信息化的发展&#xff0c;网络与信息安全得到了进一步重视&#xff0c;越…

【总线】AXI4第八课时:介绍AXI的 “原子访问“ :独占访问(Exclusive Access)和锁定访问(Locked Access)

大家好,欢迎来到今天的总线学习时间!如果你对电子设计、特别是FPGA和SoC设计感兴趣&#xff0c;那你绝对不能错过我们今天的主角——AXI4总线。作为ARM公司AMBA总线家族中的佼佼者&#xff0c;AXI4以其高性能和高度可扩展性&#xff0c;成为了现代电子系统中不可或缺的通信桥梁…

力扣习题--找不同

目录 前言 题目和解析 1、找不同 2、 思路和解析 总结 前言 本系列的所有习题均来自于力扣网站LeetBook - 力扣&#xff08;LeetCode&#xff09;全球极客挚爱的技术成长平台 题目和解析 1、找不同 给定两个字符串 s 和 t &#xff0c;它们只包含小写字母。 字符串 t…

Web 基础与 HTTP 协议

Web 基础与 HTTP 协议 一、Web 基础1.1域名和 DNS域名的概念Hosts 文件DNS&#xff08;Domain Name System 域名系统&#xff09;域名注册 1.2网页与 HTML网页概述HTML 概述网站和主页Web1.0 与 Web2.0 1.3静态网页与动态网页静态网页动态网页 二、HTTP 协议1.1HTTP 协议概述1.…

跨界客户服务:拓展服务边界,创造更多价值

在当今这个日新月异的商业时代&#xff0c;跨界合作已不再是新鲜词汇&#xff0c;它如同一股强劲的东风&#xff0c;吹散了行业间的壁垒&#xff0c;为企业服务创新开辟了前所未有的广阔天地。特别是在客户服务领域&#xff0c;跨界合作正以前所未有的深度和广度&#xff0c;拓…

刷题之多数元素(leetcode)

多数元素 哈希表解法&#xff1a; class Solution { public:/*int majorityElement(vector<int>& nums) {//map记录元素出现的次数&#xff0c;遍历map&#xff0c;求出出现次数最多的元素unordered_map<int,int>map;for(int i0;i<nums.size();i){map[nu…