Pytorch系列教程:模型训练的基本要点

PyTorch是一个开源的机器学习库,由于其灵活性和动态计算图而迅速流行起来。在PyTorch中训练模型是任何数据科学家或机器学习工程师的基本技能。本文将指导您完成使用PyTorch训练模型所需的基本步骤。

总体说明

模型训练流程主要包括数据准备、网络构建、优化配置及迭代训练。首先将数据划分为训练集、验证集和测试集,通过归一化和数据增强预处理后,利用DataLoader实现批量加载。接着定义包含输入层、隐藏层和输出层的神经网络结构,确保各层维度匹配数据特征。选择交叉熵损失函数衡量预测误差,并基于SGD或Adam等优化器调整参数。训练时通过前向传播输出预测,反向传播计算梯度并更新权重,结合动量和学习率控制收敛速度。完成后在测试集上无梯度验证模型性能,统计准确率等指标评估泛化能力。最终通过超参数调优(如调整学习率、网络结构)优化模型效果,形成完整的训练闭环。
在这里插入图片描述

下面针对关键步骤,结合示例分别进行说明。

步骤1:安装和设置

在我们深入研究训练模型之前,必须正确设置PyTorch。PyTorch可以使用pip轻松安装。执行如下命令安装:

pip install torch torchvision

确保你有兼容版本的Python和CUDA(如果你使用GPU支持),以获得有效的设置。

步骤2:准备数据

数据准备是至关重要的一步。PyTorch提供了torchvision等工具来简化此过程。你可能通常需要将数据集分为训练子集和测试子集。

from torchvision import datasets, transforms# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# Download and load the training data
trainset = datasets.MNIST(root='./mnist_data', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

关键说明:

  1. MNIST图像是灰度图(单通道),因此转换后张量形状为 (1, 28, 28)
  2. Normalize方式实现归一化,归一化公式:(x - mean) / std
    • 均值(mean)=(0.5):将像素值从[0,255]映射到[-1,1]
    • 标准差(std)=(0.5):配合均值使数据分布更适合神经网络

步骤3:构建模型

在设置数据之后,下一步是定义模型体系结构。一个简单的前馈神经网络可以作为一个很好的起点。

import torch.nn as nn
import torch.nn.functional as Fclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(28 * 28, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = x.view(-1, 28 * 28)  # Flatten the inputx = F.relu(self.fc1(x))x = self.fc2(x)return x
  • MNIST输入形状从 (batch_size, 1, 28, 28)(batch_size, 784) ,数学计算过程:

    • 第一层: 784 features → 512 neurons
      计算公式:y = W1x + b1
      激活函数:ReLU(y) = max(0, y)

    • 第二层: 512 neurons → 10 neurons

      计算公式:z = W2y + b2 输出结果直接作为分类logits(未归一化)

步骤4:定义损失函数和优化器

损失函数和优化器的选择会显著影响训练过程。对于像MNIST这样的分类任务,使用CrossEntropyLoss和SGD优化器。

import torch.optim as optimnet = Net()   #实例化模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

步骤5:训练模型

这一步包括迭代数据,将其传递到网络中,计算损失,并更新权重。下面是PyTorch中的一个简单的训练循环:

for epoch in range(5):  # loop over the dataset multiple timesrunning_loss = 0.0for inputs, labels in trainloader:# Zero the parameter gradientsoptimizer.zero_grad()# Forward + backward + optimizeoutputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()print(f'Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}')

步骤6:评估模型

最后,实现基于测试数据评估模型的技术;这有助于确保你的模型预测是有价值的。

# Load test data
testset = datasets.MNIST(root='./mnist_data', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)correct = 0
total = 0
with torch.no_grad():for inputs, labels in testloader:outputs = net(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy: {100 * correct / total}%')

该代码片段完成了从数据加载到模型评估的完整流程,是机器学习项目标准验证环节的典型实现。实际应用中可根据具体需求扩展为集成测试框架。

最后总结

模型训练的核心是让网络从数据中学习规律以最小化预测误差。流程分为数据预处理、模型定义、训练执行与评估优化三阶段。数据需标准化并分批次输入,模型结构需适配数据特征,损失函数与优化器共同决定训练方向。训练时通过前向传播生成预测,反向传播更新参数,迭代直至收敛。测试阶段验证模型泛化能力,超参数调优进一步提升性能。整个过程强调数据质量、模型设计和训练策略的协同作用,目标是构建高效稳定的预测系统。

通过遵循这些步骤并有效地利用PyTorch的强大功能,您可以训练和改进神经网络以解决各种机器学习问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/28590.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NVIDIA(英伟达) GPU 芯片架构发展史

GPU 性能的关键参数 CUDA 核心数量(个):决定了 GPU 并行处理能力,在 AI 等并行计算类业务下,CUDA 核心越多性能越好。 显存容量(GB):决定了 GPU 加载数据量的大小,在 AI…

汽车免拆诊断案例 | 2023款丰田雷凌汽油版车行驶中偶尔出现通信故障

故障现象  一辆2023款丰田雷凌汽油版车,搭载1.5 L发动机,累计行驶里程约为4700 km。车主反映,行驶中偶尔组合仪表上的发动机转速信号丢失,转向变重,且有“闯车”感,同时车辆故障警报蜂鸣器鸣响。 故障诊断…

鸿蒙与DeepSeek深度整合:构建下一代智能操作系统生态

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 https://www.captainbed.cn/north 目录 技术融合背景与价值鸿蒙分布式架构解析DeepSeek技术体系剖析核心整合架构设计智能调度系统实现…

AutoGen学习笔记系列(七)Tutorial - Managing State

这篇文章瞄准的是AutoGen框架官方教程中的 Tutorial 章节中的 Managing State 小节,主要介绍了如何对Team内的状态管理,特别是如何 保存 与 加载 状态,这对于Agent系统而言非常重要。 官网链接:https://microsoft.github.io/auto…

cenos7网络安全检查

很多网络爱好者都知道,在Windows 2000和Windows 9x的命令提示符下可使用Windows系统自带的多种命令行网络故障检测工具,比如说我们最常用的ping。但大家在具体应用时,可能对这些命令行工具的具体含义,以及命令行后面可以使用的种…

MagicDriveDiT:具有自适应控制的自动驾驶高分辨率长视频生成

24年11月来自香港中文大学、香港科技大学和华为公司的论文“MagicDriveDiT: High-Resolution Long Video Generation for Autonomous Driving with Adaptive Control”。 扩散模型的快速进步极大地改善视频合成,特别是可控视频生成,这对于自动驾驶等应用…

大模型架构记录1

整体的学习架构 一 模型构建和应用 1 训练数据 (重点) 2 模型设计 (transformer) 3 模型训练 (fine-tuning 微调) (产品经理后面可能能做) 4 benchmark (评测) 5 memory (内存)(知识图谱,向量数据库) 6 搜索技…

【Liunx专栏_3】Liunx进程概念知识点

文章目录 前言1、冯诺依曼体系结构2、操作系统2.1、系统调用 3、进程3.1、进程概念3.2、进程描述—PCB3.3、查看进程信息3.4、通过系统调用获取进程标识符3.5、通过系统调用创建子进程—fork() 4、进程状态5、僵尸进程6、孤儿进程7、进程优先级7.1、PRI和NI是什么?7…

Nacos简介、安装与使用(保姆级教程!!!)

目录 一、Nacos 简介 1. 什么是 Nacos 2. Nacos 的核心功能 3. Nacos 的优势 二、Nacos 安装 1. 环境准备 2. 下载 Nacos 3. 解压安装包 4. 启动 Nacos 三、Nacos 使用 1. 服务注册与发现 (1)引入依赖 (2)配置 Nacos…

多线程JUC(二)

目录 一、等待唤醒机制1.生产者消费者2.阻塞队列3.线程的状态 二、线程池1.理解与使用2.自定义线程池 三、线程池额外知识 一、等待唤醒机制 1.生产者消费者 等待唤醒机制可以简单的理解为下图。厨师相当于生产者,吃货相当于消费者。当桌子(缓冲区&…

【仿muduo库one thread one loop式并发服务器实现】

文章目录 一、项目介绍1-1、项目总体简介1-2、项目开发环境1-3、项目核心技术1-4、项目开发流程1-5、项目如何使用 二、框架设计2-1、功能模块划分2-1-1、SERVER模块2-1-2、协议模块 2-2、项目蓝图2-2-1、整体图2-2-2、模块关系图2-2-2-1、Connection 模块关系图2-2-2-2、Accep…

关于tresos Studio(EB)的MCAL配置之GPT

概念 GPT,全称General Purpose Timer,就是个通用定时器,取的名字奇怪了点。定时器是一定要的,要么提供给BSW去使用,要么提供给OS去使用。 配置 General GptDeinitApi控制接口Gpt_DeInit是否启用 GptEnableDisable…

STM32Cubemx配置E22-xxxT22D lora模块实现定点传输

文章目录 一、STM32Cubemx配置二、定点传输**什么是定点传输?****定点传输的特点****定点传输的工作方式****E22 模块定点传输配置****如何启用定点传输?****示例** **应用场景****总结** **配置 1:C0 00 07 00 02 04 62 00 17 40****解析** …

多线程-线程本地变量ThreadLocal

简介 ThreadLocal是线程本地变量,用于存储独属于线程的变量,这些变量可以在同一个线程内跨方法、跨类传递。每一个ThreadLocal对象,只能为当前线程关联一个数据,如果要为当前线程关联多个数据,就需要使用多个ThreadLo…

Python练习(握手问题,进制转换,日期问题,位运算,求和)

一. 握手问题 代码实现 ans0for i in range(1,51):for j in range(i1,51):if i<7 and j<7:continueelse:ans 1print(ans) 这道题可以看成是50个人都握了手减去7个人没握手的次数 答案&#xff1a;1204 二.将十进制整数拆解 2.1门牌制作 代码实现 ans0for i in ra…

DeepSeek 角色设定与风格控制

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家&#xff0c;历代文学网&#xff08;PC端可以访问&#xff1a;https://literature.sinhy.com/#/?__c1000&#xff0c;移动端可微信小程序搜索“历代文学”&#xff09;总架构师&#xff0c;15年工作经验&#xff0c;精通Java编…

网络原理--HTTP协议

http中文名为超文本传输协议&#xff0c;所谓“超文本”就是指传输范围超出了能在UTF8等码表上找到的字符的范围&#xff0c;包含一些图片&#xff0c;特殊格式之类的。 HTTP的发展简介 从图中可以看出到现在已经发展出了HTTP3&#xff0c;但是市面上的主流还是以HTTP1.0为主。…

学习工具的一天之(burp)

第一呢一定是先下载 【Java环境】&#xff1a;Java Downloads | Oracle 下来是burp的下载 Download Burp Suite Community Edition - PortSwigger 【下载方法二】关注的一个博主 【BurpSuite 安装激活使用详细上手教程 web安全测试工具】https://www.bilibili.com/video/BV…

Java后端高频面经——Mysql

3. Mysql(21) 第三范式的作用与原理&#xff1f;&#xff08;B站&#xff09; 数据库范式有 3 种&#xff1a; 1NF(第一范式)&#xff1a;属性不可再分。 1NF 是所有关系型数据库的最基本要求 &#xff0c;也就是说关系型数据库中创建的表一定满足第一范式。 2NF(第二范式)&am…

React:Router路由

ReactRouter引入 在index.js里编辑&#xff0c;创建路由&#xff0c;绑定路由 import React from react; import ReactDOM from react-dom/client; import ./index.css; import reportWebVitals from ./reportWebVitals; import { createBrowserRouter, RouterProvider } from…