PyTorch快速入门教程【小土堆】之完整模型训练套路

视频地址完整的模型训练套路(一)_哔哩哔哩_bilibili

import torch
import torchvision
from model import *
from torch import nn
from torch.utils.data import DataLoader# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="CIFAR10", train=False, transform=torchvision.transforms.ToTensor(),download=True)
# Length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10,训练数据集的长度为:10
# print("训练数据集的长度为: {}".format(train_data_size))
# print("测试数据集的长度为: {}".format(test_data_size))# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
tudui = Tudui()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10for i in range(epoch):print("--------第{}轮训练开始---------".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))# 测试步骤开始total_test_loss =0total_accuracy = 0with torch.no_grad():# 保证不会调优for data in test_dataloader:imgs, targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss: {}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy / test_data_size))torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/500449.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

若依框架之简历pdf文档预览功能

一、前端 (1)安装插件vue-pdf:npm install vue-pdf (2)引入方式:import pdf from "vue-pdf"; (3)components注入方式:components:{pdf} (4&…

永磁同步电机负载估计算法--自适应龙伯格观测器

一、原理介绍 龙贝格扰动观测器的参数可以通过带宽配置法进行整定,将观测器带宽设为L,选取大的L可以加快观测器的收敛速度,但是L过大会导致系统阶跃响应出现超调、稳态性能差等问题。因此,在龙贝格观测器中引入表征系统状态变量x…

Python机器学习笔记(十七、分箱、离散化、线性模型与树)

数据表示的最佳方法:取决于数据的语义,所使用的模型种类。 线性模型与基于树的模型(决策树、梯度提升树和随机森林)是两种成员很多同时又非常常用的模 型,它们在处理不同的特征表示时就具有非常不同的性质。我们使用w…

Spring Boot介绍、入门案例、环境准备、POM文件解读

文章目录 1.Spring Boot(脚手架)2.微服务3.环境准备3.1创建SpringBoot项目3.2导入SpringBoot相关依赖3.3编写一个主程序;启动Spring Boot应用3.4编写相关的Controller、Service3.5运行主程序测试3.6简化部署 4.Hello World探究4.1POM文件4.1.1父项目4.1.2父项目的父…

嵌入式入门Day35

网络编程 Day2 套接字socket基于TCP通信的流程服务器端客户端TCP通信API 基于UDP通信的流程服务器端客户端 作业 套接字socket socket套接字本质是一个特殊的文件,在原始的Linux中,它和管道,消息队列,共享内存,信号等…

安卓系统主板_迷你安卓主板定制开发_联发科MTK安卓主板方案

安卓主板搭载联发科MT8766处理器,采用了四核Cortex-A53架构,高效能和低功耗设计。其在4G网络待机时的电流消耗仅为10-15mA/h,支持高达2.0GHz的主频。主板内置IMG GE832 GPU,运行Android 9.0系统,内存配置选项丰富&…

centos,789使用mamba快速安装R及语言包devtools

如何进入R语言运行环境请参考:Centos7_miniconda_devtools安装_R语言入门之R包的安装_r语言devtools包怎么安装-CSDN博客 在R里面使用安装devtools经常遇到依赖问题,排除过程过于费时,使用conda安装包等待时间长等。下面演示centos,789都是一…

人工智能(AI)简史:推动新时代的科技力量

一、人工智能简介 人工智能(AI,Artificial Intelligence)是计算机科学的一个分支,旨在研究和开发可以模拟、扩展或增强人类智能的系统。它涉及多种技术和方法,包括机器学习、深度学习、自然语言处理(NLP&a…

【笔记】在虚拟机中通过apache2给一个主机上配置多个web服务器

(配置出来的web服务器又叫虚拟主机……) 下载apache2 sudo apt update sudo apt install apache2 (一)ip相同 web端口不同的web服务器 进入 /var/www/html 创建站点一和站点二的目录文件(目录文件名自定义哈&#x…

linux装git

前言 以 deepin 深度系统为例,安装命 令行版 Git 非常简单。 安装 注意:需要输入账号密码,否则无法进行。 打开终端,执行如下命令即可。 sudo apt-get install git成功 如下图所示,输入 git ,命令识别即…

【Spark】架构与核心组件:大数据时代的必备技能(下)

🐇明明跟你说过:个人主页 🏅个人专栏:《大数据前沿:技术与应用并进》🏅 🔖行路有良友,便是天堂🔖 目录 一、引言 1、什么是Apache Spark 2、Spark 的应用场景&…

NLP中的神经网络基础

一:多层感知器模型 1:感知器 解释一下,为什么写成 wxb>0 ,其实原本是 wx > t ,t就是阈值,超过这个阈值fx就为1,现在把t放在左边。 在感知器里面涉及到两个问题: 第一个,特征提…

第十一章 图论

题目描述: 阿里这学期修了计算机组织和架构课程。他了解到指令之间可能存在依赖关系,比如WAR(读后写)、WAW、RAW。 如果两个指令之间的距离小于安全距离,则会导致危险,从而可能导致错误的结果。因此&#…

嵌入式系统 第七讲 ARM-Linux内核

• 7.1 ARM-Linux内核简介 • 内核:是一个操作系统的核心。是基于硬件的第一层软件扩充, 提供操作系统的最基本的功能,是操作系统工作的基础,它负责管理系统的进程、内存、设备驱动程序、文件和网络系统, 决定着系统的…

win11蓝屏死机 TPM-WMI

1. 打开win11的事件查看器,定位错误 最近两次都是 KB5016061:安全启动数据库和 DBX 变量更新事件 - Microsoft 支持 事件源 TPM-WMI 事件 ID 1796 2. 解决方案 打开注册表:计算机\HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Contro…

Linux命令——3.网络与用户

文章目录 一、网络1.网络测试与诊断2.网络接口配置3.无线网络配置4.防火墙与网络管理6.防火墙管理1)firewalld命令2)iptables命令 二、用户和群组1.管理员模式2.用户账户管理1)useradd创建2)usermod修改3)userdel 删除…

机器学习算法基础知识1:决策树

机器学习算法基础知识1:决策树 一、本文内容与前置知识点1. 本文内容2. 前置知识点 二、场景描述三、决策树的训练1. 决策树训练方式(1)分类原则-Gini(2)分类原则-entropy(3)加权系数-样本量&am…

_使用CLion的Vcpkg安装SDL2,添加至CMakelists时报错,编译报错

语言:C20 编译器:gcc 14.2 摘要:初次使用Vcpkg添加SDL2,出现CMakelists找不到错误、编译缺失main错误、运行失败错误。 CMakelists缺失错误: 使用CLion的Vcpkg安装SDL2时,按照指示把对应代码添加至CMakel…

Lumos学习王佩丰Excel第二十二讲:制作甘特图与动态甘特图

一、制作双向条形图 1. 分离坐标轴 2. 自定义坐标轴数字格式:加分号加正常数字 3. 修改图表背景 修改图片艺术效果:虚化图片 二、制作甘特图 1、甘特图定义 甘特图(Gantt chart)又称为横道图、条状图(Bar chart&…

el-pagination 为什么只能展示 10 条数据(element-ui@2.15.13)

好的&#xff0c;我来帮你分析前端为什么只能展示 10 条数据&#xff0c;以及如何解决这个问题。 问题分析&#xff1a; pageSize 的值&#xff1a; 你的 el-pagination 组件中&#xff0c;pageSize 的值被设置为 10&#xff1a;<el-pagination:current-page"current…