深度学习 线性神经网络(线性回归 从零开始实现)

介绍:

在线性神经网络中,线性回归是一种常见的任务,用于预测一个连续的数值输出。其目标是根据输入特征来拟合一个线性函数,使得预测值与真实值之间的误差最小化。

线性回归的数学表达式为:
y = w1x1 + w2x2 + ... + wnxn + b

其中,y表示预测的输出值,x1, x2, ..., xn表示输入特征,w1, w2, ..., wn表示特征的权重,b表示偏置项。

训练线性回归模型的目标是找到最优的权重和偏置项,使得模型预测的输出与真实值之间的平方差(即损失函数)最小化。这一最优化问题可以通过梯度下降等优化算法来解决。

线性回归在深度学习中也被广泛应用,特别是在浅层神经网络中。在深度学习中,通过将多个线性回归模型组合在一起,可以构建更复杂的神经网络结构,以解决更复杂的问题。

 手动生成数据集:

%matplotlib inline
import torch
from d2l import torch as d2l
import random#"""生成y=Xw+b+噪声"""
def synthetic_data(w, b, num_examples):  #生成num_examples个样本X = d2l.normal(0, 1, (num_examples, len(w)))#随机x,长度为特征个数,权重个数y = d2l.matmul(X, w) + b#y的函数y += d2l.normal(0, 0.01, y.shape)#加上0~0.001的随机噪音return X, d2l.reshape(y, (-1, 1))#返回true_w = d2l.tensor([2, -3.4])#初始化真实w
true_b = 4.2#初始化真实bfeatures, labels = synthetic_data(true_w, true_b, 1000)#随机一些数据
print(features)
print(labels)

显示数据集:

print('features:', features[0],'\nlabel:', labels[0])'''
features: tensor([ 2.1714, -0.6891]) 
label: tensor([10.8673])
'''d2l.set_figsize()
d2l.plt.scatter(d2l.numpy(features[:, 1]), d2l.numpy(labels), 1);

读取小批量数据集:

#每次抽取一批量样本
def data_iter(batch_size, features, labels):#步长、特征、标签num_examples = len(features)#特征个数indices = list(range(num_examples))random.shuffle(indices)# 这些样本是随机读取的,没有特定的顺序,打乱顺序for i in range(0, num_examples, batch_size):#随机访问,步长为batch_sizebatch_indices = d2l.tensor(indices[i: min(i + batch_size, num_examples)])yield features[batch_indices], labels[batch_indices]

定义模型:

#定义模型
def linreg(X, w, b):  """线性回归模型"""return d2l.matmul(X, w) + b

定义损失函数:

#定义损失和函数
def squared_loss(y_hat, y):  #@save"""均方损失"""return (y_hat - d2l.reshape(y, y_hat.shape)) ** 2 / 2

定义优化算法(小批量随机梯度下降):

#定义优化算法  """小批量随机梯度下降"""
def sgd(params, lr, batch_size):  #参数、lr学习率、with torch.no_grad():for param in params:param -= lr * param.grad / batch_sizeparam.grad.zero_()

模型训练:

#训练
lr = 0.03#学习率
num_epochs = 3#数据扫三遍
net = linreg#模型
loss = squared_loss#损失函数
#初始化模型参数
w = torch.normal(0, 0.01, size=(2,1), requires_grad=True)#权重
b = torch.zeros(1, requires_grad=True)#b全赋为0for epoch in range(num_epochs):for X, y in data_iter(batch_size, features, labels):#拿出一批量x,yl = loss(net(X, w, b), y)  # X和y的小批量损失,实际的和预测的# 因为l形状是(batch_size,1),而不是一个标量。l中的所有元素被加到一起,# 并以此计算关于[w,b]的梯度l.sum().backward()sgd([w, b], lr, batch_size)  # 使用参数的梯度更新参数with torch.no_grad():train_l = loss(net(features, w, b), labels)print(f'epoch {epoch + 1}, loss {float(train_l.mean()):f}')
'''
epoch 1, loss 0.037302
epoch 2, loss 0.000140
epoch 3, loss 0.000048
'''print(f'w的估计误差: {true_w - d2l.reshape(w, true_w.shape)}')
print(f'b的估计误差: {true_b - b}')
'''
w的估计误差: tensor([0.0006, 0.0001], grad_fn=<SubBackward0>)
b的估计误差: tensor([-0.0003], grad_fn=<RsubBackward1>)
'''print(w)
'''
tensor([[ 1.9994],[-3.4001]], requires_grad=True)
'''print(b)
'''
tensor([4.2003], requires_grad=True)
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/284537.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode每日一题——统计桌面上的不同数字

统计桌面上的不同数字OJ链接&#xff1a;2549. 统计桌面上的不同数字 - 力扣&#xff08;LeetCode&#xff09; 题目&#xff1a; 思路&#xff1a; 这是一个很简单的数学问题&#xff1a; 当n 5时&#xff0c;因为n % 4 1&#xff0c;所以下一天4一定会被放上桌面 当n 4…

TnT-LLM: Text Mining at Scale with Large Language Models

TnT-LLM: Text Mining at Scale with Large Language Models 相关链接&#xff1a;arxiv 关键字&#xff1a;Large Language Models (LLMs)、Text Mining、Label Taxonomy、Text Classification、Prompt-based Interface 摘要 文本挖掘是将非结构化文本转换为结构化和有意义的…

QT(C++)-error LNK2038: 检测到“_ITERATOR_DEBUG_LEVEL”的不匹配项: 值“2”不匹配值“0”

1、项目场景&#xff1a; 在VS中采用QT&#xff08;C&#xff09;调试时&#xff0c;出现error LNK2038: 检测到“_ITERATOR_DEBUG_LEVEL”的不匹配项: 值“2”不匹配值“0”错误 2、解决方案&#xff1a; 在“解决方案资源管理器”中选中出现此类BUG的项目&#xff0c;右键-…

【NLP笔记】预训练+微调范式之OpenAI Transformer、ELMo、ULM-FiT、Bert..

文章目录 OpenAI TransformerELMoULM-FiTBert基础结构Embedding预训练&微调 【原文链接】&#xff1a; BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding 【本文参考链接】 The Illustrated BERT, ELMo, and co. (How NLP Cracked Tra…

提面 | 面试抽题

学习到更新日期面试抽题-1.2案例分析的思维本质2024-3-23 1提面抽屉论述问题的分类 1.1案例分析占总论 1.2案例分析的思维本质

计算机网络相关

OSI七层模型 各层功能&#xff1a; TCP/IP四层模型 应用层 传输层 网络层 网络接口层 访问一个URL的全过程 在浏览器中输入指定网页的 URL。 浏览器通过 DNS 协议&#xff0c;获取域名对应的 IP 地址。 浏览器根据 IP 地址和端口号&#xff0c;向目标服务器发起一个 TCP…

【进阶五】Python实现SDVRP(需求拆分)常见求解算法——离散粒子群算法(DPSO)

基于python语言&#xff0c;采用经典离散粒子群算法&#xff08;DPSO&#xff09;对 需求拆分车辆路径规划问题&#xff08;SDVRP&#xff09; 进行求解。 目录 往期优质资源1. 适用场景2. 代码调整3. 求解结果4. 代码片段参考 往期优质资源 经过一年多的创作&#xff0c;目前已…

linux文本三剑客 --- grep、sed、awk

1、grep grep&#xff1a;使用正则表达式搜索文本&#xff0c;将匹配的行打印出来&#xff08;匹配到的标红&#xff09; 命令格式&#xff1a;grep [option] pattern file <1> 命令参数 -A<显示行数>&#xff1a;除了显示符合范本样式的那一列之外&#xff0c;并…

C语言中的联合体和枚举

联合体 联合体的创建 联合体的关键字是union union S {char a;int i; };除了关键字和结构体不一样之外&#xff0c;联合体的创建语法形式和结构体的很相似&#xff0c;如果不熟悉结构体的创建&#xff0c;可以看一下我上一篇的博客关于结构体知识的详解。 联合体的特点 联合…

Personal Website

Personal Website Static Site Generators hexo hugo jekyll Documentation Site Generator gitbook vuepress vitepress docsify docute docusaurus Deployment 1. GitHub Pages 2. GitLab Pages 3. vercel 4. netlify Domain 域名注册 freessl 域名解析域名…

【GUI】自动化办公

目录 一、GUI介绍 二、环境安装 三、鼠标移动操作 四、鼠标点击操作 五、拖动鼠标 六、鼠标滚动操作 七、屏幕快照&图像识别基础 7.1 屏幕快照&#xff08;截图&#xff09; 7.2 图像识别 八、键盘控制 一、GUI介绍 GUI自动化就是写程序直接控制键盘和鼠标。这些…

电脑如何关闭自启动应用?cmd一招解决问题

很多小伙伴说电脑刚开机就卡的和定格动画似的&#xff0c;cmd一招解决问题&#xff1a; CtrlR打开cmd,输入&#xff1a;msconfig 进入到这个界面&#xff1a; 点击启动&#xff1a; 打开任务管理器&#xff0c;禁用不要的自启动应用就ok了

【prometheus-operator】k8s监控集群外redis

1、部署exporter GitHub - oliver006/redis_exporter: Prometheus Exporter for Redis Metrics. Supports Redis 2.x, 3.x, 4.x, 5.x, 6.x, and 7.x redis_exporter-v1.57.0.linux-386.tar.gz # 解压 tar -zxvf redis_exporter-v1.57.0.linux-386.tar.gz # 启动 nohup ./redi…

SpringCloud-记

目录 什么是SpringCloud 什么是微服务 SpringCloud的优缺点 SpringBoot和SpringCloud的区别 RPC 的实现原理 RPC是什么 eureka的自我保护机制 Ribbon feigin优点 Ribbon和Feign的区别 什么是SpringCloud Spring Cloud是一系列框架的有序集合。它利用Spring Boot的开发…

力扣面试150 阶乘后的零 数论 找规律 质因数

Problem: 172. 阶乘后的零 思路 &#x1f468;‍&#x1f3eb; 大佬神解 一个数末尾有多少个 0 &#xff0c;取决于这个数 有多少个因子 10而 10 可以分解出质因子 2 和 5而在阶乘种&#xff0c;2 的倍数会比 5 的倍数多&#xff0c;换而言之&#xff0c;每一个 5 都会找到一…

AI时代Python金融大数据分析实战:ChatGPT让金融大数据分析插上翅膀

❤️作者主页&#xff1a;小虚竹 ❤️作者简介&#xff1a;大家好,我是小虚竹。2022年度博客之星评选TOP 10&#x1f3c6;&#xff0c;Java领域优质创作者&#x1f3c6;&#xff0c;CSDN博客专家&#x1f3c6;&#xff0c;华为云享专家&#x1f3c6;&#xff0c;掘金年度人气作…

电脑照片分辨率怎么调?这款dpi修改工具好用

许多考试平台在上传证件照片的时候&#xff0c;大多都会对图片分辨率有具体要求&#xff0c;但是如果遇上手上的图片分辨率达不到要求&#xff0c;那么怎么改图片分辨率呢&#xff1f;可以利用专业的dpi修改工具来处理&#xff0c;比如今天分享的就是一个在线修改图片分辨率的方…

唯众物联网安装调试员实训平台物联网一体化教学实训室项目交付山东技师学院

近日&#xff0c;山东技师学院物联网安装调试员实训平台及物联网一体化教学实训室采购项目已顺利完成交付并投入使用&#xff0c;标志着学院在物联网技术教学与实践应用方面迈出了坚实的一步。 山东技师学院作为国内知名的技师培养摇篮&#xff0c;一直以来致力于为社会培养高…

select , poll, epoll思维导图

目录 1. 总的框架结构 2. select 3. poll 4. epoll 1. 总的框架结构 2. select

自动驾驶轨迹规划之时空语义走廊(一)

欢迎大家关注我的B站&#xff1a; 偷吃薯片的Zheng同学的个人空间-偷吃薯片的Zheng同学个人主页-哔哩哔哩视频 (bilibili.com) 目录 1.摘要 2.系统架构 3.MPDM 4.时空语义走廊 ​4.1 种子生成 4.2 具有语义边界的cube inflation ​4.3 立方体松弛 本文解析了丁文超老师…