Pytorch深度学习教程_10_神经网络训练

欢迎来到《深度学习保姆教程》系列的第九篇!在前面的几篇中,我们已经介绍了python基本用法,学习了梯度、激活函数、损失函数、优化算法等,在上一个教程中我们学习了搭建神经网络的nn模块,今天我们学习如何训练神经网络。


目录

1 数据加载和预处理

(1)理解数据格式

数据加载

(2)数据清洗

(3)数据预处理

(4)数据划分

2 Loop

关键组件

3 评估指标

(1)分类指标

(2)回归指标

(3)选择合适的指标

4 模型保存和加载

(1)保存模型

保存整个模型

仅保存模型的状态字典

(2)加载模型

加载整个模型

 加载状态字典

 5 小结


1 数据加载和预处理

数据是机器学习模型的根本。你如何准备和处理数据会显著影响模型性能。

(1)理解数据格式

数据可以有多种格式:

  • CSV/Excel: 表格数据,包含行和列。
  • JSON: 键值对结构的数据。
  • 图像: 基于像素的表示。
  • 文本: 字符或单词的序列。
  • 音频/视频: 多通道的时间序列数据。
数据加载
  • : 使用像Pandas、NumPy和OpenCV这样的库进行高效的数据加载。
  • 文件格式: 处理不同的文件格式和编码。
  • 数据结构: 将数据转换为适当的数据结构(数组、张量)。
import pandas as pd
import numpy as np# 从CSV加载数据
data = pd.read_csv('data.csv')# 转换为numpy数组
data_array = data.to_numpy()

(2)数据清洗

  • 缺失值: 处理缺失数据(插补、删除)。
  • 异常值: 识别和处理异常值(移除、封顶、转换)。
  • 数据不一致: 修正错误和不一致性。
import pandas as pd# 处理缺失值
data = data.fillna(method='ffill')  # 用前一个值填充缺失值# 移除异常值
outlier_threshold = 100  # 假设阈值为100
data = data[data['column_name'] < outlier_threshold]

(3)数据预处理

  • 归一化: 将数值特征缩放到特定范围(0-1,-1到1)。
  • 标准化: 中心化和缩放特征,使其具有零均值和单位方差。
  • 特征编码: 将分类数据转换为数值格式(独热编码、标签编码)。
  • 特征提取: 从原始数据中提取相关特征。
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
data_scaled = scaler.fit_transform(data)

(4)数据划分

  • 训练集、验证集和测试集: 将数据划分为用于模型训练、评估和测试的子集。
from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

 

2 Loop

循环是机器学习的核心。它是将数据输入模型、计算误差并更新参数的迭代过程。

典型的训练循环包括以下步骤:

  • 数据加载: 从数据集中获取一批数据。
  • 前向传播: 将数据通过模型以获得预测。
  • 损失计算: 计算预测与真实值之间的差异。
  • 反向传播: 计算损失相对于模型参数的梯度。
  • 参数更新: 根据梯度使用优化器调整模型参数。
import torch# 假设你有一个模型、优化器和数据加载器
for epoch in range(num_epochs):for i, (inputs, labels) in enumerate(train_loader):# 将参数梯度清零optimizer.zero_grad()# 前向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()
关键组件
  • 轮次 (Epoch)‌: 完整遍历整个数据集一次。
  • 批量大小 (Batch size)‌: 一次处理的样本数量。
  • 优化器 (Optimizer)‌: 用于更新模型参数的算法。
  • 损失函数 (Loss function)‌: 测量预测值与真实值之间误差的函数。

 

3 评估指标

评估指标是我们衡量模型性能的标尺。它们提供了关于模型优势和劣势的见解,帮助我们了解模型在未见数据上的泛化能力。

(1)分类指标

对于分类问题,常用的指标包括:

准确率 (Accuracy)‌: 正确预测的比例。

from sklearn.metrics import accuracy_scoreaccuracy = accuracy_score(y_true, y_pred)

 精确率 (Precision)‌: 被正确预测为正类的正类预测比例

from sklearn.metrics import precision_scoreprecision = precision_score(y_true, y_pred)

召回率 (Recall)‌: 实际正类中被正确识别的比例。

from sklearn.metrics import recall_scorerecall = recall_score(y_true, y_pred)

F1分数 (F1-score)‌: 精确率和召回率的调和平均值。

from sklearn.metrics import f1_scoref1 = f1_score(y_true, y_pred)

混淆矩阵 (Confusion matrix)‌: 一个总结分类算法性能的表格。

(2)回归指标

对于回归问题,常用的指标包括:

  • 均方误差 (Mean Squared Error, MSE)‌: 预测值与实际值之间平方差的平均值。
from sklearn.metrics import mean_squared_errormse = mean_squared_error(y_true, y_pred)
  • 平均绝对误差 (Mean Absolute Error, MAE)‌: 预测值与实际值之间绝对差的平均值。
from sklearn.metrics import mean_absolute_errormae = mean_absolute_error(y_true, y_pred)
  • 决定系数 (R-squared)‌: 模型解释的因变量方差的比例。
from sklearn.metrics import r2_scorer2 = r2_score(y_true, y_pred)

(3)选择合适的指标

指标的选择取决于问题和期望的结果:

  • 不平衡数据集: 精确率、召回率和F1分数可能比准确率更有信息量。
  • 异常值: MAE可能比MSE对异常值更稳健。
  • ROC曲线和AUC: 用于评估分类模型,特别是在不平衡数据集上。
  • 对数损失 (Log loss)‌: 衡量概率分类模型的性能。
  • 自定义指标: 为特定问题创建定制的指标。

 

4 模型保存和加载

保存和加载训练好的模型对于可重现性、部署和分享你的工作至关重要。PyTorch 提供了方便的工具来实现这一目的。

(1)保存模型

PyTorch 提供了两种主要的方法来保存模型:

保存整个模型

这种方法保留了模型的架构和参数。

import torchtorch.save(model, 'model.pth')

仅保存模型的状态字典

这会保存模型的参数,允许你将它们加载到不同的模型架构中(如果兼容的话)。

torch.save(model.state_dict(), 'model_params.pth')

(2)加载模型

要加载一个已保存的模型:

加载整个模型

loaded_model = torch.load('model.pth')

 加载状态字典

model = MyModel(*args, **kwargs)  # 创建模型实例
model.load_state_dict(torch.load('model_params.pth'))

 5 小结

本篇博客快速介绍神经网络训练的基本概念,包括数据加载处理、训练轮次、评估指标及模型保存加载。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/40071.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

阳台光伏新守护者:电流传感器助力安全发电

安科瑞顾强 插即用光伏&#xff08;Plug-In Solar PV&#xff09;以其便捷的安装方式和亲民的准入标准&#xff0c;正在推动欧洲能源结构的革新性转变。根据SolarPower Europe发布的最新行业报告显示&#xff0c;预计到2025年&#xff0c;仅德国通过认证的即插即用光伏系统注册…

【工程记录】QwQ-32b 8bit量化部署教程(vLLM | 缓解复读)

文章目录 写在前面1. 环境配置2. 下载QwQ-32b 8bit量化模型3. 使用vLLM本地推理 写在前面 仅作个人学习记录用。本文记录QwQ-32b 8bit量化模型的部署的详细方法。 1. 环境配置 以下环境经测试无bug&#xff08;Deepseek R1用这个环境也能直接跑&#xff09;&#xff1a; gp…

Elasticsearch 入门

Elasticsearch 入门 1. 认识 Elasticsearch 1.1 现有查询数据存在的问题 查询效率较低 由于数据库模糊查询不走索引&#xff0c;在数据量较大的时候&#xff0c;查询性能很差。 功能单一 数据库的模糊搜索功能单一&#xff0c;匹配条件非常苛刻&#xff0c;必须恰好包含用户…

Docker镜像相关命令(Day2)

文章目录 前言一、问题描述二、相关命令1.查看镜像2.搜索镜像3.拉取镜像4.删除镜像5.镜像的详细信息6.标记镜像 三、验证与总结 前言 Docker 是一个开源的容器化平台&#xff0c;它让开发者能够将应用及其依赖打包到一个标准化的单元&#xff08;容器&#xff09;中运行。在 D…

网站服务器常见的CC攻击防御秘籍!

CC攻击对网站的运营是非常不利的&#xff0c;因此我们必须积极防范这种攻击&#xff0c;但有些站长在防范这种攻击时可能会陷入误区。让我们先了解下CC攻击&#xff01; CC攻击是什么 CC是DDoS攻击的一种&#xff0c;CC攻击是借助代理服务器生成指向受害主机的合法请求&#x…

【PICO】开发环境配置准备

Unity编辑器配置 安装Unity编辑器 安装UnityHub 安装Unity2021.3.34f1c1 添加安卓平台模块 Pico软件资源准备 资源准备地址&#xff1a;Pico Developer PICO SDK PICO Unity Integration SDK PICO Unity Integration SDK 为 PICO 基于 Unity 引擎研发的软件开发工具…

传输层安全协议 SSL/TLS 详细介绍

传输层安全性协议TLS及其前身安全套接层SSL是一种安全传输协议&#xff0c;目前TLS协议已成为互联网上保密通信的工业标准&#xff0c;在浏览器、邮箱、即时通信、VoIP等应用程序中得到广泛的应用。本文对SSL和TLS协议进行一个详细的介绍&#xff0c;以便于大家更直观的理解和认…

一文解读DeepSeek在工业制造领域的应用

引言 在当今数字化浪潮席卷全球的背景下&#xff0c;各个行业都在积极寻求创新与变革&#xff0c;工业制造领域也不例外。然而&#xff0c;传统工业制造在生产效率、质量控制、成本管理等方面面临着诸多挑战。在这一关键时期&#xff0c;人工智能技术的兴起为工业制造带来了新的…

3.Excel:快速分析

补充&#xff1a;快捷键&#xff1a;CTRLQ 一 格式化 1.数据条 2.色阶 3.开始菜单栏里面选择更多 补充&#xff1a;想知道代表什么意思&#xff1a;管理规则-编辑规则 二 表格 点击后会变成超级表&#xff0c;之前是普通表。 三 迷你图 图放在单元格里面。 补充&#xff1a;除了…

区间端点(java)(贪心问题————区间问题)

deepseek给了一种超级简单的做法 我是真的想不到 贪心的思路是 局部最优——>全局最优 这种我是真的没有想到&#xff0c;这样的好处就是后面便利的时候可以通过foreach循环直接便利qu的子元素也就是对应的某一个区间, 将一个二维数组变成一维数组&#xff0c;每一个一维…

STM32蜂鸣器播放音乐

STM32蜂鸣器播放音乐 STM32蜂鸣器播放音乐 Do, Re, Mi, Fa, 1. 功能概述 本系统基于STM32F7系列微控制器&#xff0c;实现了以下功能&#xff1a; 通过7个按键控制蜂鸣器发声&#xff0c;按键对应不同的音符。每个按键对应一个音符&#xff08;Do, Re, Mi, Fa, Sol, La, Si&a…

基于docker-compose 部署可道云资源管理器

容器编排Explorer 容器化部署MariaDB容器化部署Redis容器化部署PHP容器化部署Nginx编排部署compose服务 var code “9861ce02-1202-405b-b419-4dddd337aaa7” GitHub官网 KodExplorer 是一款网页文件管理器。它也是一个网页代码编辑器&#xff0c;可让你直接在网页浏览器中开…

【Git】--- Git远程操作 标签管理

Welcome to 9ilks Code World (๑•́ ₃ •̀๑) 个人主页: 9ilk (๑•́ ₃ •̀๑) 文章专栏&#xff1a; Git 前面我们学习的操作都是在本地仓库进行了&#xff0c;如果团队内多人协作都在本地仓库操作是不行的&#xff0c;此时需要新的解决方案 --- 远程仓库。…

Deepseek API+Python 测试用例一键生成与导出 V1.0.3

** 功能详解** 随着软件测试复杂度的不断提升,测试工程师需要更高效的方法来设计高覆盖率的测试用例。Deepseek API+Python 测试用例生成工具在 V1.0.3 版本中,新增了多个功能点,优化了提示词模板,并增强了对文档和接口测试用例的支持,极大提升了测试用例设计的智能化和易…

Axure RP9.0 教程:左侧菜单列表导航 ( 点击父级菜单,子菜单自动收缩或展开)【响应式的菜单导航】

文章目录 引言I 实现步骤添加商品管理菜单组推拉效果引言 应用场景:PC端管理后台页面,左侧菜单列表导航。 思路: 用到了动态面板的两个交互效果来实现:隐藏/显示切换、展开/收起元件向下I 实现步骤 添加商品管理菜单组 在左侧画布区域添加一个菜单栏矩形框;再添加一个商…

详细比较StringRedisTemplate和RedisTemplate的区别及使用方法,及解决融合使用方法

前言 感觉StringRedisTemplate和RedisTemplate非常的相识&#xff0c;到底有什么区别和联系呢&#xff1f;点开idea&#xff0c;打开其依赖关系&#xff0c;可以看出只需使用maven依赖包spring-boot-starter-data-redis&#xff0c;然后在service中注入StringRedisTemplate或者…

SpringSecurity——前后端分离登录认证

SpringSecurity——前后端分离登录认证的整个过程 前端&#xff1a; 使用Axios向后端发送请求 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>登录</title><script src"https://cdn…

如何用腾讯云建站做好一个多语言的建筑工程网站?海外用户访问量提升3倍!分享我的经验

作为新疆地区领先的工程建筑企业&#xff0c;我们深知在数字化浪潮中&#xff0c;一个专业、高效且具备国际视野的官方网站是企业形象与业务拓展的“门面担当”。然而&#xff0c;传统的建站流程复杂、技术门槛高、多语言适配难等问题&#xff0c;曾让我们在数字化转型中举步维…

遥控器钥匙学习---通过uds指令

1、实际报文 2、硬件配置信息 使用原gateway硬件&#xff0c;软件基于sbcm-main工程新建的一个分支。主要用于钥匙学习的指令发送。 3、后续更改 这里需要细化一下&#xff0c;为了后续方便测试 4、钥匙学习策略 可以学习2把钥匙 一次可以学习把钥匙&#xff0c;uds命令&…

QinQ项展 VLAN 空间

随着以太网技术在网络中的大量部署&#xff0c;利用 VLAN 对用户进行隔离和标识受到很大限制。因为 IEEE802.1Q 中定义的 VLAN Tag 域只有 12 个比特&#xff0c;仅能表示 4096 个 VLAN&#xff0c;无法满足城域以太网中标识大量用户的需求&#xff0c;于是 QinQ 技术应运而生。…