线性回归实现

1.从数据流水线、模型、损失函数、小批量随机梯度下降优化器

%matplotlib inline
import random
import torch
from d2l import torch as d2l

2.根据带有噪声的线性模型构造人造数据集。使用线性模型参数w =  [2,-3.4]T、b = 4.2和噪声项ε生成数据集及标签

y = Xw + b + ε

def synthetic_data(w, b, num_examples):"""生成 y = Xw + b + 噪声。"""X = torch.normal(0, 1, (num_examples, len(w)))y = torch.matmul(X, w) + by += torch.normal(0, 0.01, y.shape)return X, y.reshape((-1, 1))true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = synthetic_data(true_w, true_b, 1000)

3.features每行都包含二维数据样本,labels每行都包含一维标签值(标量)

print('features:', features[0], '\nlabel:', labels[0])

d2l.set_figsize()
d2l.plt.scatter(features[:, (1)].detach().numpy(),labels.detach().numpy(),1);

4.定义data_iter函数,该函数接收批量大小、特征矩阵、标签向量作为输入,生成大小为batch_size的小批量

def data_iter(batch_size, features, labels):num_examples = len(features)indices = list(range(num_examples))random.shuffle(indices)for i in range(0, num_examples, batch_size):batch_indices = torch.tensor(indices[i:min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]batch_size = 10for X, y in data_iter(batch_size, features, labels):print(X, '\n', y)break

5.定义初始化模型参数

w = torch.normal(0, 0.01, size = (2, 1),requires_grab = True)
b = torch.zeros(1, requires_grab = True)

6.定义模型

def linreg(X, w, b):"""线性回归模型"""return torch.matmul(X, w) + b

7.定义损失函数

def squared_loss(y_hat, y):"""均方损失。"""return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

8.定义优化算法

def sgd(params, lr, batch_size):"""小批量随机梯度下降。"""with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

9.训练过程

lr = 0.03
num_epochs = 3
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w, b], lr, batch_size)with torch.no_grad():train_1 = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_1.mean()):f}')

增加训练epoch,将num_epochs提高到10,观察损失值变化情况。

lr = 0.03
num_epochs = 10
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w, b], lr, batch_size)with torch.no_grad():train_1 = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_1.mean()):f}')

增大学习率(lr),将lr调高到10,看看损失值是否有显著变化。如果学习率太小,模型参数更新不明显,会导致损失值几乎不变。

lr = 10
num_epochs = 10
net = linreg
loss = squared_lossfor epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):l = loss(net(X, w, b), y)l.sum().backward()sgd([w, b], lr, batch_size)with torch.no_grad():train_1 = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_1.mean()):f}')

10. 学习率过高分析

    训练过程中每个 epoch 的损失值变成了 NaN(Not a Number)。这种情况通常是因为学习率过高,导致了梯度爆炸,使得参数更新变得不稳定,从而产生了 NaN。

    当学习率过高时,每次参数更新的步长就会非常大,这可能会导致模型参数变得异常大,从而使损失计算结果出现溢出或 NaN。

    在进行反向传播时,由于梯度过大,模型的更新会导致权重变得极端,从而使损失无法正常计算。如果损失函数中有平方或指数运算,那么过高的学习率会导致梯度变得异常大,进一步放大参数更新的幅度,从而导致模型训练不稳定。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/453379.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

windows 上验证请求接口是否有延迟

文件名:api_request_script.bat ,直接右键点击执行即可。 echo off setlocal:: 配置:: 替换为实际接口URL set "logFilelog.txt" set "errorLogFileerror_log.txt" set "interval3" :: 请求间隔(秒&#xff…

React之组件渲染性能优化

关键词: shouldComponentUpdate、PureComnent、React.memo、useMemo、useCallback shouldComponentUpdate 与 PureComnent shouldComponentUpdate 与 PureComnent 用于类组件。虽然官方推荐使用函数组件,但我们依然需要对类组件的渲染优化策略有所了解…

面经汇总——第一篇

1. int数据类型做了什么优化 Java在处理整数类型时,进行了多种优化,主要体现在编译器层面和JVM层面,目的是提高性能、减少内存开销。 常量池优化 Java中的Integer类有一个缓存机制,对于值在-128到127之间的int数字,Int…

springBoot集成nacos注册中心以及配置中心

一、安装启动nacos 访问&#xff1a;http://127.0.0.1:8848/nacos/index.html#/login 二、工程集成nacos 1、引入依赖 我这里搭建的父子工程哈&#xff0c;在子工程引入 <dependencies><!-- SpringBoot Web --><dependency><groupId>org.sp…

代码审计-Python Flask

1.Jinjia2模版注入 Flask是一个使用 Python 编写的轻量级 Web 应用框架。其 WSGI 工具箱采用 Werkzeug &#xff0c;模板引擎则使用 Jinja2。jinja2是Flask作者开发的一个模板系统&#xff0c;起初是仿django模板的一个模板引擎&#xff0c;为Flask提供模板支持&#xff0c;由于…

MySQL-30.索引-介绍

一.索引 为什么需要索引&#xff1f;当我们没有建立索引时&#xff0c;要在一张数据量极其庞大的表中查询表里的某一个值&#xff0c;会非常的消耗时间。以一个6000000数据量的表为例&#xff0c;查询一条记录的时间耗时约为13s&#xff0c;这是因为要查询符合某个值的数据&am…

RabbitMQ系列学习笔记(八)--发布订阅模式

文章目录 一、发布订阅模式原理二、发布订阅模式实战1、消费者代码2、生产者代码3、查看运行结果 本文参考&#xff1a; 尚硅谷RabbitMQ教程丨快速掌握MQ消息中间件rabbitmq RabbitMQ 详解 Centos7环境安装Erlang、RabbitMQ详细过程(配图) 一、发布订阅模式原理 在开发过程中&…

SpringBoot+MyBatis+MySQL项目基础搭建

一、新建项目 1.1 新建springboot项目 新建项目 选择SpringBoot&#xff0c;填写基本信息&#xff0c;主要是JDK版本和项目构建方式&#xff0c;此处以JDK17和Maven举例。 1.2 引入依赖 选择SpringBoot版本&#xff0c;勾选Lombok&#xff0c;Spring Web&#xff0c;MyBa…

UI自动化测试 —— web端元素获取元素等待实践!

前言 Web UI自动化测试是一种软件测试方法&#xff0c;通过模拟用户行为&#xff0c;自动执行Web界面的各种操作&#xff0c;并验证操作结果是否符合预期&#xff0c;从而提高测试效率和准确性。 目的&#xff1a; 确保Web应用程序的界面在不同环境(如不同浏览器、操作系统)下…

设计模式和软件框架的关系

设计模式和软件框架在软件开发中都有助于解决复杂问题和提高代码质量&#xff0c;但它们在概念和使用上存在一些区别。它们的关系可以通过以下几点理解&#xff1a; 层次与抽象程度 设计模式&#xff08;Design Patterns&#xff09;是一组通用的、可复用的解决方案&#xff0c…

完爆YOLOv10!Transformer+目标检测新算法性能无敌,狠狠拿捏CV顶会!

百度最近又搞了波大的&#xff0c;推出了一种全新的实时端到端目标检测算法RT-DETRv3&#xff0c;性能&耗时完爆YOLOv10。 RT-DETRv3基于Transformer设计&#xff0c;属于代表模型DETR的魔改进化版。这类目标检测模型都有着强大的扩展性与通用性&#xff0c;因为Transform…

MySQL—CRUD—进阶—(二) (ಥ_ಥ)

文本目录&#xff1a; ❄️一、新增&#xff1a; ❄️二、查询&#xff1a; 1、聚合查询&#xff1a; 1&#xff09;、聚合函数&#xff1a; 2&#xff09;、GROUP BY子句&#xff1a; 3&#xff09;、HAVING 子句&#xff1a; 2、联合查询&#xff1a; 1&#xff09;、内连接…

基于FPGA的以太网设计(五)

之前简单介绍并实现了ARP协议&#xff0c;今天简单介绍一下IP协议和ICMP协议。 1.IP协议 IP协议即Internet Protocol&#xff0c;是网络层的协议。 IP协议是TCP/IP协议族的核心协议&#xff0c;其主要包含两个方面&#xff1a; IP头部信息。IP头部信息出现在每个IP数据报中…

第13篇:无线与移动网络安全

目录 引言 13.1 无线网络的安全威胁 13.2 无线局域网的安全协议 13.3 移动通信中的安全机制 13.4 蓝牙和其他无线技术的安全问题 13.5 无线网络安全的最佳实践 13.6 总结 第13篇&#xff1a;无线与移动网络安全 引言 无线和移动网络的发展为我们的生活带来了极大的便利…

边缘计算与联邦学习:探索隐私保护和高效数据处理的结合

个人主页&#xff1a;chian-ocean 文章专栏 边缘计算与联邦学习&#xff1a;探索隐私保护和高效数据处理的结合 1. 引言 随着物联网(IoT)设备的普及&#xff0c;网络边缘产生了大量数据。将这些数据上传至云端进行集中式计算和处理&#xff0c;既有隐私泄露的风险&#xff…

15分钟学Go 实战项目一:命令行工具

实战项目一&#xff1a;命令行工具 1. 引言 命令行工具是开发者常用的工具之一&#xff0c;它可以帮助用户通过命令行界面对程序进行控制和交互。在这节中&#xff0c;我们将创建一个简单的命令行工具&#xff0c;以帮助你理解Go语言的基本语法和如何处理命令行输入。在这个过…

详解安卓和IOS的唤起APP的机制,包括第三方平台的唤起方法比如微信

网页唤起APP是一种常见的跨平台交互方式&#xff0c;它允许用户从网页直接跳转到移动应用程序。 这种技术广泛应用于各种场景&#xff0c;比如让用户在浏览器中点击链接后直接打开某个应用&#xff0c;或者从网页引导用户下载安装应用。实现这一功能主要依赖于URL Scheme、Univ…

ESP32-S3学习笔记:分区表(Partition Table)的二进制分析

一、参考资料 用于研究的官方示例代码&#xff1a;esp-idf-v5.3\examples\storage\partition_api\partition_find参考的官方文档&#xff1a;ESP-IDF编程指南&#xff1a;分区表 二、准备工作 用VS Code打开示例代码&#xff0c;打开示例代码的CSV自定义分区表&#xff0c;如…

大数据实验3: HDFS基础编程

实验3&#xff1a; HDFS基础编程 一、实验目的 HDFS的shell命令使用HDFS的JAVA API使用&#xff1b; 二、实验平台 操作系统&#xff1a;Linux&#xff08;Ubuntu16.04&#xff09;&#xff1b;Hadoop版本&#xff1a;3.3.1&#xff1b;JDK版本&#xff1a;1.8&#xff1b;…

498.对角线遍历

目录 题目解法代码说明&#xff1a;输出&#xff1a; 如何确定起始点&#xff1f;解释一下max(0,d−m1)是什么意思&#xff1f; 如何遍历对角线&#xff1f;.push_back是怎么用的&#xff1f; 题目 给你一个大小为 m x n 的矩阵 mat &#xff0c;请以对角线遍历的顺序&#xf…