深度学习(13)--PyTorch搭建神经网络进行气温预测

一.搭建神经网络进行气温预测流程详解

1.1.导入所需的工具包

import numpy as np  # 矩阵计算
import pandas as pd   # 数据读取
import matplotlib.pyplot as plt  # 画图处理
import torch  # 构建神经网络
import torch.optim as optim  # 设置优化器

1.2.读取并处理数据

引入数据并查看数据的格式

# 引入数据
features = pd.read_csv('temps.csv')# 看看数据长什么样子
print(features.head())

Pandas库中的.head()函数,取数据的前n行数据,默认是取前五行数据,如上图所示。

查看数据维度

print('数据维度:', features.shape)

shape函数的功能是读取矩阵的长度,.shape直接输出数据的维度,如上图,表示该数据的维度为348行,9列。对应的也就是348个样本,9个特征。

而shape[0],shape[1]则分别返回矩阵第一维度、第二维度的长度:

# 查看数据维度
print('数据维度:', features.shape[0])
print('数据维度:', features.shape[1])

处理时间数据

# 处理时间数据
import datetime# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]

 查看处理的datas数据格式

print(dates[:3])

对特殊数据进行one-hot编码 

计算机无法识别字符串数据,所以对于字符串数据需要使用one-hot编码:

features = pd.get_dummies(features) 

# get_dummies会自动判断数据中哪一列是字符串,并自动将字符串展开。

# eg:数据中用于标注星期的字符串一共有七个,则get_dummies函数将数据展开成七列,当天是哪一天就在相应位置标1。

# 星期 一 二 三 四 五 六 七,如果是星期一则标注为:1 0 0 0 0 0 0,如果是星期三则标注为:0 0 1 0 0 0 0,如果是星期六则标注为:0 0 0 0 0 1 0

 查看one-hot编码后的数据

对标签进行处理

# 标签
labels = np.array(features['actual'])  # 获取标签:features获取actual的标签然后再转换为np.array的格式# 在特征中去掉标签
features= features.drop('actual', axis = 1)  # 去除features中的actual标签,axis表示沿着行/列去除,axis=0按行计算,axis=1按列计算# 名字单独保存一下,以备后患
feature_list = list(features.columns)  # 保存features中的columns值,也就是列# 转换成合适的格式
features = np.array(features)  # 把处理后的features数据也转换为np.array格式

标准化处理 

不同的数据取值范围不同,而机器又会认为数值大的数据较为重要,所以需要对数据进行标准化(x-μ/σ) -- μ为均值,σ为标准差。

from sklearn import preprocessing
input_features = preprocessing.StandardScaler().fit_transform(features)  # fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。

fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。

标准化处理前后的数据:

1.3.构建网络模型

构建网络

本项目构建的网络模型较为简单,只有一个隐层

# shape[0]是样本数,也就是行的数据,shape[1]是特征数,也就是列的数据
input_size = input_features.shape[1]  
hidden_size = 128
output_size = 1
batch_size = 16  # 一次迭代batch个样本
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size),  # 根据输入自动初始化权重参数和偏重值torch.nn.ReLU(),  # 激活函数 Sigmoid/Relutorch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')  # 损失函数设置:MSE均方误差
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)  
# 优化器设置:Adam,参数为网络中的所有参数以及学习率

训练网络

# 训练网络
losses = []
# 迭代1000次,epoch = 1000
for i in range(1000):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_features), batch_size):  # 循环范围为0~样本数,每次循环中间间隔batchs_sizeend = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 做一个索引是否越界的判断# 取得一个batch的数据xx = torch.tensor(input_features[start:end], dtype = torch.float, requires_grad = True)yy = torch.tensor(labels[start:end], dtype = torch.float, requires_grad = True)prediction = my_nn(xx)  # 输入值经过定义的网络运算得到预测值loss = cost(prediction, yy)  # 参数为预测值和真实值optimizer.zero_grad()  # torch的迭代过程中会累计之前的训练结果,所以在每次迭代中需要清空梯度值loss.backward(retain_graph=True)  # 反向传播optimizer.step()  # 对所有参数进行更新batch_loss.append(loss.data.numpy())# 打印损失if i % 100==0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))

预测训练结果 

x = torch.tensor(input_features, dtype = torch.float)  
# 先将数据转换为tensor格式,因为需要在网络中进行运算
predict = my_nn(x).data.numpy()  
# 在网络中运算完成中,再转换为data.numpy格式,因为后续需要进行画图处理

1.4.对结果进行画图对比

# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})   # predict是x经过网络训练再转换为np.array的值# 画图# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')  # 参数分别为:横轴,纵轴,曲线颜色,标签值# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')  # 参数分别为:横轴,纵轴,曲线颜色,标签值
plt.xticks(rotation=30)  # x轴参数倾斜60°
plt.legend()  # 使上述代码产生效果# 图名
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')  # x,y轴标签设置
plt.title('Actual and Predicted Values')  # 图名设置# 保存图片并展示
plt.savefig("result.png")
plt.show()

二.完整代码

import numpy as np  # 矩阵计算
import pandas as pd   # 数据读取
import matplotlib.pyplot as plt  # 画图处理
import torch  # 构建神经网络
import torch.optim as optim  # 设置优化器# 处理时间数据
import datetimefrom sklearn import preprocessing# 引入数据
features = pd.read_csv('temps.csv')# 看看数据长什么样子
# print(features.head())# 查看数据维度
# print('数据维度:', features.shape)# 分别得到年,月,日
years = features['year']
months = features['month']
days = features['day']'''
# datetime格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]print(dates[:3])
'''# 独热(one-hot)编码 -- 机器不认识字符串,需要将字符串转换为机器认识的参数
features = pd.get_dummies(features)
# get_dummies会自动判断数据中哪一列是字符串,并自动将字符串展开
# eg:数据中用于标注星期的字符串一共有七个,则get_dummies函数将数据展开成七列,当天是哪一天就在相应位置标1
# 星期 一 二 三 四 五 六 七,如果是星期一则标注为:1 0 0 0 0 0 0,如果是星期三则标注为:0 0 1 0 0 0 0,如果是星期六则标注为:0 0 0 0 0 1 0
# print(features.head(5))# 标签
labels = np.array(features['actual'])  # 获取标签:features获取actual的标签然后再转换为np.array的格式# 在特征中去掉标签
features = features.drop('actual', axis = 1)  # 去除features中的actual标签,axis表示沿着行/列去除,axis=0按行计算,axis=1按列计算# 名字单独保存一下,以备后患
feature_list = list(features.columns)  # 保存features中的columns值,也就是列# 转换成合适的格式
features = np.array(features)  # 把处理后的features数据也转换为np.array格式# print(features[0])
# 标准化处理
input_features = preprocessing.StandardScaler().fit_transform(features)
# fit_transform通过数据计算出均值和标准差,再对数据进行标准化处理变换。
# print(input_features[0])# shape[0]是样本数,也就是行的数据,shape[1]是特征数,也就是列的数据
input_size = input_features.shape[1]
hidden_size = 128
output_size = 1
batch_size = 16  # 一次迭代batch个样本
my_nn = torch.nn.Sequential(torch.nn.Linear(input_size, hidden_size),  # 根据输入自动初始化权重参数和偏重值torch.nn.ReLU(),  # 激活函数 Sigmoid/ReLUtorch.nn.Linear(hidden_size, output_size),
)
cost = torch.nn.MSELoss(reduction='mean')  # 损失函数设置:MSE均方误差
optimizer = torch.optim.Adam(my_nn.parameters(), lr=0.001)
# 优化器设置:Adam,参数为网络中的所有参数以及学习率# 训练网络
losses = []
# 迭代1000次,epoch = 1000
for i in range(1000):batch_loss = []# MINI-Batch方法来进行训练for start in range(0, len(input_features), batch_size):  # 循环范围为0~样本数,每次循环中间间隔batchs_sizeend = start + batch_size if start + batch_size < len(input_features) else len(input_features)  # 做一个索引是否越界的判断# 取得一个batch的数据xx = torch.tensor(input_features[start:end], dtype=torch.float, requires_grad=True)yy = torch.tensor(labels[start:end], dtype=torch.float, requires_grad=True)prediction = my_nn(xx)  # 输入值经过定义的网络运算得到预测值loss = cost(prediction, yy)  # 参数为预测值和真实值optimizer.zero_grad()  # torch的迭代过程中会累计之前的训练结果,所以在每次迭代中需要清空梯度值loss.backward(retain_graph=True)  # 反向传播optimizer.step()  # 对所有参数进行更新batch_loss.append(loss.data.numpy())'''# 打印损失if i % 100 == 0:losses.append(np.mean(batch_loss))print(i, np.mean(batch_loss))'''# 预测训练结果
x = torch.tensor(input_features, dtype=torch.float)
# 先将数据转换为tensor格式,因为需要在网络中进行运算
predict = my_nn(x).data.numpy()
# 在网络中运算完成中,再转换为data.numpy格式,因为后续需要进行画图处理# 转换日期格式
dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in dates]# 创建一个表格来存日期和其对应的标签数值
true_data = pd.DataFrame(data={'date': dates, 'actual': labels})# 同理,再创建一个来存日期和其对应的模型预测值
months = features[:, feature_list.index('month')]
days = features[:, feature_list.index('day')]
years = features[:, feature_list.index('year')]test_dates = [str(int(year)) + '-' + str(int(month)) + '-' + str(int(day)) for year, month, day in zip(years, months, days)]
test_dates = [datetime.datetime.strptime(date, '%Y-%m-%d') for date in test_dates]predictions_data = pd.DataFrame(data = {'date': test_dates, 'prediction': predict.reshape(-1)})   # predict是x经过网络训练再转换为np.array的值# 画图# 真实值
plt.plot(true_data['date'], true_data['actual'], 'b-', label='actual')  # 参数分别为:横轴,纵轴,曲线颜色,标签值# 预测值
plt.plot(predictions_data['date'], predictions_data['prediction'], 'ro', label='prediction')  # 参数分别为:横轴,纵轴,曲线颜色,标签值
plt.xticks(rotation=30)  # x轴参数倾斜60°
plt.legend()  # 使上述代码产生效果# 图名
plt.xlabel('Date')
plt.ylabel('Maximum Temperature (F)')  # x,y轴标签设置
plt.title('Actual and Predicted Values')  # 图名设置# 保存图片并展示
plt.savefig("result.png")
plt.show()

三.输出结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/255935.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第五篇:MySQL常见数据类型

MySQL中的数据类型有很多&#xff0c;主要分为三类:数值类型、字符串类型、日期时间类型 三个表格都在此网盘中&#xff0c;需要者可移步自取&#xff0c;如果觉得有帮助希望点个赞~ MySQL常见数据类型表 数值类型 &#xff08;注&#xff1a;decimal类型举例&#xff0c;如1…

DevOps:CI、CD、CB、CT、CD

目录 一、软件开发流程演化快速回顾 &#xff08;一&#xff09;瀑布模型 &#xff08;二&#xff09;原型模型 &#xff08;三&#xff09;螺旋模型 &#xff08;四&#xff09;增量模型 &#xff08;五&#xff09;敏捷开发 &#xff08;六&#xff09;DevOps 二、走…

【开源】SpringBoot框架开发天沐瑜伽馆管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 瑜伽课程模块2.3 课程预约模块2.4 系统公告模块2.5 课程评价模块2.6 瑜伽器械模块 三、系统设计3.1 实体类设计3.1.1 瑜伽课程3.1.2 瑜伽课程预约3.1.3 系统公告3.1.4 瑜伽课程评价 3.2 数据库设计3.2.…

【Kubernetes】kubectl top pod 异常?

目录 前言一、表象二、解决方法1、导入镜像包2、编辑yaml文件3、解决问题 三、优化改造1.修改配置文件2.检查api-server服务是否正常3.测试验证 总结 前言 各位老铁大家好&#xff0c;好久不见&#xff0c;卑微涛目前从事kubernetes相关容器工作&#xff0c;感兴趣的小伙伴相互…

【Kubernetes】在k8s1.24及以上版本基于containerd容器运行时测试pod从harbor拉取镜像

基于containerd容器运行时测试pod从harbor拉取镜像 1、安装高版本containerd2、安装docker3、登录harbor上传镜像4、从harbor拉取镜像 1、安装高版本containerd 集群中各个节点都要操作 yum remove containerd.io -y yum install containerd.io-1.6.22* -y cd /etc/containe…

《UE5_C++多人TPS完整教程》学习笔记2 ——《P3 多人游戏概念(Multiplayer Concept)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P3 多人游戏概念&#xff08;Multiplayer Concept&#xff09;》 的学习笔记&#xff0c;该系列教学视频为 Udemy 课程 《Unreal Engine 5 C Multiplayer Shooter》 的中文字幕翻译版&#xff0c;UP主&#xff08;也是译…

数据结构:并查集讲解

并查集 1.并查集原理2.并查集实现3.并查集应用4.并查集的路径压缩 1.并查集原理 在一些应用问题中&#xff0c;需要将n个不同的元素划分成一些不相交的集合。开始时&#xff0c;每个元素自成一个单元素集合&#xff0c;然后按一定的规律将归于同一组元素的集合合并。在此过程中…

分享88个鼠标特效,总有一款适合您

分享88个鼠标特效&#xff0c;总有一款适合您 88个鼠标特效下载链接&#xff1a;https://pan.baidu.com/s/1ljcxwgXGpw7baiufUGJjZA?pwd8888 提取码&#xff1a;8888 Python采集代码下载链接&#xff1a;采集代码.zip - 蓝奏云 学习知识费力气&#xff0c;收集整理更不…

5G NR 频率计算

5G中引入了频率栅格的概念&#xff0c;也就是小区中心频点和SSB的频域位置不能随意配置&#xff0c;必须满足一定规律&#xff0c;主要目的是为了UE能快速的搜索小区&#xff1b;其中三个最重要的概念是Channel raster 、synchronization raster和pointA。 1、Channel raster …

Hive正则表达式

Hive版本&#xff1a;hive-3.1.2 一、Hive的正则表达式概述 正则表达式是一种用于匹配和操作文本的强大工具&#xff0c;它是由一系列字符和特殊字符组成的模式&#xff0c;用于描述要匹配的文本模式。 Hive的正则表达式灵活使用解决HQL开发过程中的很多问题&#xff0c;本篇文…

H12-821_26

26.下列选项中,哪些路由前缀满足下面的IP-Prefix条件? A.20.0.1.0/24 B.20.0.1.0/23 C.20.0.1.0/25 D.20.0.1.0/28 答案&#xff1a;ACD 注释&#xff1a; 前缀列表可以匹配路由前缀和网络掩码。 ip ip-prefix test index 10 permit 20.0.0.0 16 greater-equal 24 less-equal…

Zephyr NRF7002 实现AppleJuice

BLE的基础知识 ble的信道和BR/EDR的信道是完全不一样的。但是范围是相同的&#xff0c;差不多也都是2.4Ghz的频道。可以简单理解为空中有40个信道0~39信道。两个设备在相同的信道里面可以进行相互通信。 而这些信道SIG又重新编号&#xff1a; 这个编号就是把37 38 39。 3个信道…

Seurat - 聚类教程 (1)

设置 Seurat 对象 在本教程[1]中&#xff0c;我们将分析 10X Genomics 免费提供的外周血单核细胞 (PBMC) 数据集。在 Illumina NextSeq 500 上对 2,700 个单细胞进行了测序。可以在此处[2]找到原始数据。 我们首先读取数据。 Read10X() 函数从 10X 读取 cellranger 管道的输出&…

Linux network namespace 访问外网以及多命名空间通信(经典容器组网 veth pair + bridge 模式认知)

写在前面 整理K8s网络相关笔记博文内容涉及 Linux network namespace 访问外网方案 Demo实际上也就是 经典容器组网 veth pair bridge 模式理解不足小伙伴帮忙指正 不必太纠结于当下&#xff0c;也不必太忧虑未来&#xff0c;当你经历过一些事情的时候&#xff0c;眼前的风景已…

数据结构第十四天(树的存储/双亲表示法)

目录 前言 概述 接口&#xff1a; 源码&#xff1a; 测试函数&#xff1a; 运行结果&#xff1a; 往期精彩内容 前言 孩子&#xff0c;一定要记得你的父母啊&#xff01;&#xff01;&#xff01; 哈哈&#xff0c;今天开始学习树结构中的双亲表示法&#xff0c;让孩…

H12-821_74

74.在某路由器上查看LSP&#xff0c;看到如下结果&#xff1a; A.发送目标地址为3.3.3.3的数据包时&#xff0c;打上标签1026&#xff0c;然后发送。 B.发送目标地址为4.4.4.4的数据包时&#xff0c;不打标签直接发送。 C.当路由器收到标签为1024的数据包&#xff0c;将把标签…

MySQL数据库⑦_复合查询+内外链接(多表/子查询)

目录 1. 回顾基本查询 2. 多表查询 2.1 笛卡尔积初步过滤 3. 自连接 4. 子查询 4.1 单行子查询 4.2 多行子查询 4.2 多列子查询 4.2 from子句中使用子查询 5. 合并查询 6. 内外链接 6.1 内连接 6.2 左外链接 6.2 右外连接 本篇完。 1. 回顾基本查询 先回顾一下…

架构整洁之道-软件架构-展示器和谦卑对象、不完全边界、层次与边界、Main组件、服务

6 软件架构 6.9 展示器和谦卑对象 在《架构整洁之道-软件架构-策略与层次、业务逻辑、尖叫的软件架构、整洁架构》有我们提到了展示器&#xff08;presenter&#xff09;&#xff0c;展示器实际上是采用谦卑对象&#xff08;humble object&#xff09;模式的一种形式&#xff…

【深度优先搜索】【树】【图论】2973. 树中每个节点放置的金币数目

作者推荐 视频算法专题 本博文涉及知识点 深度优先搜索 树 图论 分类讨论 LeetCode2973. 树中每个节点放置的金币数目 给你一棵 n 个节点的 无向 树&#xff0c;节点编号为 0 到 n - 1 &#xff0c;树的根节点在节点 0 处。同时给你一个长度为 n - 1 的二维整数数组 edges…

寒假作业

手写盗版微信登入界面 #include "mainwindow.h" #include "ui_mainwindow.h"MainWindow::MainWindow(QWidget *parent): QMainWindow(parent), ui(new Ui::MainWindow) {ui->setupUi(this);this->resize(421,575);this->setFixedSize(421,575);th…