深度学习 - PyTorch基本流程 (代码)

直接上代码

import torch 
import matplotlib.pyplot as plt 
from torch import nn# 创建data
print("**** Create Data ****")
weight = 0.3
bias = 0.9
X = torch.arange(0,1,0.01).unsqueeze(dim = 1)
y = weight * X + bias
print(f"Number of X samples: {len(X)}")
print(f"Number of y samples: {len(y)}")
print(f"First 10 X & y sample: \n X: {X[:10]}\n y: {y[:10]}")
print("\n")# 将data拆分成training 和 testing
print("**** Splitting data ****")
train_split = int(len(X) * 0.8)
X_train = X[:train_split]
y_train = y[:train_split]
X_test = X[train_split:]
y_test = y[train_split:]
print(f"The length of X train: {len(X_train)}")
print(f"The length of y train: {len(y_train)}")
print(f"The length of X test: {len(X_test)}")
print(f"The length of y test: {len(y_test)}\n")# 显示 training 和 testing 数据
def plot_predictions(train_data = X_train,train_labels = y_train,test_data = X_test,test_labels = y_test,predictions = None):plt.figure(figsize = (10,7))plt.scatter(train_data, train_labels, c = 'b', s = 4, label = "Training data")plt.scatter(test_data, test_labels, c = 'g', label="Test data")if predictions is not None:plt.scatter(test_data, predictions, c = 'r', s = 4, label = "Predictions")plt.legend(prop = {"size": 14})
plot_predictions()# 创建线性回归
print("**** Create PyTorch linear regression model by subclassing nn.Module ****")
class LinearRegressionModel(nn.Module):def __init__(self):super().__init__()self.weight = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))self.bias = nn.Parameter(data = torch.randn(1,requires_grad = True,dtype = torch.float))def forward(self, x):return self.weight * x + self.biastorch.manual_seed(42)
model_1 = LinearRegressionModel()
print(model_1)
print(model_1.state_dict())
print("\n")# 初始化模型并放到目标机里
print("*** Instantiate the model ***")
print(list(model_1.parameters()))
print("\\n")# 创建一个loss函数并优化
print("*** Create and Loss function and optimizer ***")
loss_fn = nn.L1Loss()
optimizer = torch.optim.SGD(params = model_1.parameters(),lr = 0.01)
print(f"loss_fn: {loss_fn}")
print(f"optimizer: {optimizer}\n")# 训练
print("*** Training Loop ***")
torch.manual_seed(42)
epochs = 300
for epoch in range(epochs):# 将模型加载到训练模型里model_1.train()# 做 Forwardy_pred = model_1(X_train)# 计算 Lossloss = loss_fn(y_pred, y_train)# 零梯度optimizer.zero_grad()# 反向传播loss.backward()# 步骤优化optimizer.step()### 做测试if epoch % 20 == 0:# 将模型放到评估模型并设置上下文model_1.eval()with torch.inference_mode():# 做 Forwardy_preds = model_1(X_test)# 计算测试 losstest_loss = loss_fn(y_preds, y_test)# 输出测试结果print(f"Epoch: {epoch} | Train loss: {loss:.3f} | Test loss: {test_loss:.3f}")# 在测试集上对训练模型做预测
print("\n")
print("*** Make predictions with the trained model on the test data. ***")
model_1.eval()
with torch.inference_mode():y_preds = model_1(X_test)
print(f"y_preds:\n {y_preds}")
## 画图
plot_predictions(predictions = y_preds) # 保存训练好的模型
print("\n")
print("*** Save the trained model ***")
from pathlib import Path 
## 创建模型的文件夹
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents = True, exist_ok = True)
## 创建模型的位置
MODEL_NAME = "trained model"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME 
## 保存模型到刚创建好的文件夹
print(f"Saving model to {MODEL_SAVE_PATH}")
torch.save(obj = model_1.state_dict(), f = MODEL_SAVE_PATH)
## 创建模型的新类型
loaded_model = LinearRegressionModel()
loaded_model.load_state_dict(torch.load(f = MODEL_SAVE_PATH))
## 做预测,并跟之前的做预测
y_preds_new = loaded_model(X_test)
print(y_preds == y_preds_new)

结果如下

**** Create Data ****
Number of X samples: 100
Number of y samples: 100
First 10 X & y sample: X: tensor([[0.0000],[0.0100],[0.0200],[0.0300],[0.0400],[0.0500],[0.0600],[0.0700],[0.0800],[0.0900]])y: tensor([[0.9000],[0.9030],[0.9060],[0.9090],[0.9120],[0.9150],[0.9180],[0.9210],[0.9240],[0.9270]])**** Splitting data ****
The length of X train: 80
The length of y train: 80
The length of X test: 20
The length of y test: 20**** Create PyTorch linear regression model by subclassing nn.Module ****
LinearRegressionModel()
OrderedDict([('weight', tensor([0.3367])), ('bias', tensor([0.1288]))])*** Instantiate the model ***
[Parameter containing:
tensor([0.3367], requires_grad=True), Parameter containing:
tensor([0.1288], requires_grad=True)]*** Create and Loss function and optimizer ***
loss_fn: L1Loss()
optimizer: SGD (
Parameter Group 0dampening: 0differentiable: Falseforeach: Nonelr: 0.01maximize: Falsemomentum: 0nesterov: Falseweight_decay: 0
)*** Training Loop ***
Epoch: 0 | Train loss: 0.757 | Test loss: 0.725
Epoch: 20 | Train loss: 0.525 | Test loss: 0.454
Epoch: 40 | Train loss: 0.294 | Test loss: 0.183
Epoch: 60 | Train loss: 0.077 | Test loss: 0.073
Epoch: 80 | Train loss: 0.053 | Test loss: 0.116
Epoch: 100 | Train loss: 0.046 | Test loss: 0.105
Epoch: 120 | Train loss: 0.039 | Test loss: 0.089
Epoch: 140 | Train loss: 0.032 | Test loss: 0.074
Epoch: 160 | Train loss: 0.025 | Test loss: 0.058
Epoch: 180 | Train loss: 0.018 | Test loss: 0.042
Epoch: 200 | Train loss: 0.011 | Test loss: 0.026
Epoch: 220 | Train loss: 0.004 | Test loss: 0.009
Epoch: 240 | Train loss: 0.004 | Test loss: 0.006
Epoch: 260 | Train loss: 0.004 | Test loss: 0.006
Epoch: 280 | Train loss: 0.004 | Test loss: 0.006*** Make predictions wit the trained model on the test data. ***
y_preds:tensor([[1.1464],[1.1495],[1.1525],[1.1556],[1.1587],[1.1617],[1.1648],[1.1679],[1.1709],[1.1740],[1.1771],[1.1801],[1.1832],[1.1863],[1.1893],[1.1924],[1.1955],[1.1985],[1.2016],[1.2047]])*** Save the trained model ***
Saving model to models/trained model
tensor([[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True],[True]])

第一个结果图
第二个结果图

点个赞支持一下咯~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/290385.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小狐狸JSON-RPC:钱包连接,断开连接,监听地址改变

detect-metamask 创建连接,并监听钱包切换 一、连接钱包,切换地址(监听地址切换),断开连接 使用npm安装 metamask/detect-provider在您的项目目录中: npm i metamask/detect-providerimport detectEthereu…

基于Java在线考试系统系统设计与实现(源码+部署文档)

博主介绍: ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精彩专栏 推荐订阅 👇🏻 不然下次找不到 Java项目精品实…

快速上手Spring Cloud 六:容器化与微服务化

快速上手Spring Cloud 一:Spring Cloud 简介 快速上手Spring Cloud 二:核心组件解析 快速上手Spring Cloud 三:API网关深入探索与实战应用 快速上手Spring Cloud 四:微服务治理与安全 快速上手Spring Cloud 五:Spring …

Qt_day4:2024/3/25

作业1: 完善对话框,点击登录对话框,如果账号和密码匹配,则弹出信息对话框,给出提示”登录成功“,提供一个Ok按钮,用户点击Ok后,关闭登录界面,跳转到其他界面 如果账号和…

新能源充电桩站场AI视频智能分析烟火检测方案及技术特点分析

新能源汽车充电起火的原因多种多样,涉及技术、设备、操作等多个方面。从技术层面来看,新能源汽车的电池管理系统可能存在缺陷,导致电池在充电过程中出现过热、短路等问题,从而引发火灾。在设备方面,充电桩的设计和生产…

吴恩达深度学习笔记:浅层神经网络(Shallow neural networks)3.9-3.11

目录 第一门课:神经网络和深度学习 (Neural Networks and Deep Learning)第三周:浅层神经网络(Shallow neural networks)3.9 神 经 网 络 的 梯 度 下 降 ( Gradient descent for neural networks) 第一门课:神经网络和…

安防监控视频汇聚平台EasyCVR在银河麒麟V10系统中的启动异常及解决方法

安防监控视频平台EasyCVR具备较强的兼容性,它可以支持国标GB28181、RTSP/Onvif、RTMP,以及厂家的私有协议与SDK,如:海康ehome、海康sdk、大华sdk、宇视sdk、华为sdk、萤石云sdk、乐橙sdk等。平台兼容性强,支持Windows系…

【数据结构与算法】直接插入排序和希尔排序

引言 进入了初阶数据结构的一个新的主题——排序。所谓排序,就是一串记录,按照其中的某几个或某些关键字的大小(一定的规则),递增或递减排列起来的操作。 排序的稳定性:在一定的规则下,两个值…

基于springboot实现校园周边美食探索及分享平台系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现园周边美食探索及分享平台系统演示 摘要 美食一直是与人们日常生活息息相关的产业。传统的电话订餐或者到店消费已经不能适应市场发展的需求。随着网络的迅速崛起,互联网日益成为提供信息的最佳俱渠道和逐步走向传统的流通领域,传统的…

【漏洞复现】chatgpt pictureproxy.php SSRF漏洞(CVE-2024-27564)

0x01 漏洞概述 ChatGPT pictureproxy.php接口存在服务器端请求伪造 漏洞(SSRF) ,未授权的攻击者可以通过将构建的 URL 注入 url参数来强制应用程序发出任意请求。 0x02 测绘语句 fofa: icon_hash"-1999760920" 0x03 漏洞复现 G…

OSCP靶场--Codo

OSCP靶场–Codo 考点 1.nmap扫描 ## ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.229.23 -Pn -sV -sC --min-rate 2500 Starting Nmap 7.92 ( https://nmap.org ) at 2024-03-25 05:04 EDT Nmap scan report for 192.168.229.23 Host is up (0.35s latency). Not sh…

基于龙芯2k1000 mips架构ddr调试心得(二)

1、内存控制器概述 龙芯处理器内部集成的内存控制器的设计遵守 DDR2/3 SDRAM 的行业标准(JESD79-2 和 JESD79-3)。在龙芯处理器中,所实现的所有内存读/写操作都遵守 JESD79-2B 及 JESD79-3 的规定。龙芯处理器支持最大 4 个 CS(由…

基于Hive的天气情况大数据分析系统(通过hive进行大数据分析将分析的数据通过sqoop导入到mysql,通过Django基于mysql的数据做可视化)

基于Hive的天气情况大数据分析系统(通过hive进行大数据分析将分析的数据通过sqoop导入到mysql,通过Django基于mysql的数据做可视化) Hive介绍: Hive是建立在Hadoop之上的数据仓库基础架构,它提供了类似于SQL的语言&…

让IIS支持.NET Web Api PUT和DELETE请求

前言 有很长一段时间没有使用过IIS来托管应用了,今天用IIS来托管一个比较老的.NET Fx4.6的项目。发布到线上后居然一直调用不同本地却一直是正常的,关键是POST和GET请求都是正常的,只有PUT和DELETE请求是有问题的。经过一番思考忽然想起来了I…

Spring Cloud+Spring Alibaba笔记

Spring CloudSpring Alibaba 文章目录 Spring CloudSpring AlibabaNacos服务发现配置中心 OpenFeign超时机制开启httpclient5重试机制开启日志 SeataSentinel流量控制熔断降级热点控制规则持久化集成 OpenFeign集成 Gateway MicrometerZipKinGateway路由断言过滤器 Nacos 服务…

Spring用到了哪些设计模式?

目录 Spring 框架中⽤到了哪些设计模式?工厂模式单例模式1.饿汉式,线程安全2.懒汉式,线程不安全3.懒汉式,线程安全4.双重检查锁(DCL, 即 double-checked locking)5.静态内部类6.枚举单例 代理模…

C++超市商品管理系统

一、简要介绍 1.本项目为面向对象程序设计的大作业,基于Qt creator进行开发,Qt框架版本6.4.1,编译环境MINGW 11.2.0。 2.项目结构简介:关于系统逻辑部分的代码的头文件在head文件夹中,源文件在s文件夹中。与图形界面…

算法---动态规划练习-6(地下城游戏)

地下城游戏 1. 题目解析2. 讲解算法原理3. 编写代码 1. 题目解析 题目地址:点这里 2. 讲解算法原理 首先,定义一个二维数组 dp,其中 dp[i][j] 表示从位置 (i, j) 开始到达终点时的最低健康点数。 初始化数组 dp 的边界条件: 对…

基于Echarts的超市销售可视化分析系统(数据+程序+论文)

本论文旨在研究Python技术和ECharts可视化技术在超市销售数据分析系统中的应用。本系统通过对超市销售数据进行分析和可视化展示,帮助决策层更好地了解销售情况和趋势,进而做出更有针对性的决策。本系统主要包括数据处理、数据可视化和系统测试三个模块。…