001.从0开始实现线性回归(pytorch)

000动手从0实现线性回归

0. 背景介绍

我们构造一个简单的人工训练数据集,它可以使我们能够直观比较学到的参数和真实的模型参数的区别。
设训练数据集样本数为1000,输入个数(特征数)为2。给定随机生成的批量样本特征 X∈R1000×2
X∈R 1000×2 ,我们使用线性回归模型真实权重 w=[2,−3.4]⊤ 和偏差 b=4.2以及一个随机噪声项 ϵϵ 来生成标签
在这里插入图片描述

# 需要导入的包
import numpy as np
import torch
import random
from d2l import torch as d2l
from IPython import display
from matplotlib import pyplot as plt

1. 生成数据集合(待拟合)

使用python生成待拟合的数据

num_input = 2
num_example = 1000
w_true = [2,-3.4]
b_true = 4.2
features = torch.randn(num_example,num_input)
print('features.shape = '+ str(features.shape) )
labels =  w_true[0] * features[:,0] + w_true[1] * features[:,1] + b_true
labels += torch.tensor(np.random.normal(0,0.01 , size = labels.size() ),dtype = torch.float32)
print(features[0],labels[0])

2.数据的分批量处理

def data_iter(batch_size, features, labels):num_example = len(labels)indices = list(range(num_example))random.shuffle(indices)for i in range(0, num_example, batch_size):j = torch.tensor( indices[i:min(i+ batch_size,num_example)])yield features.index_select(0,j) ,labels.index_select(0,j)

3. 模型构建及训练

3.1 定义模型:

def linreg(X, w, b):return torch.mm(X,w)+b

3.2 定义损失函数

def square_loss(y, y_hat):return (y_hat - y.view(y_hat.size()))**2/2

3.3 定义优化算法

def sgd(params , lr ,batch_size):for param in params:param.data  -= lr * param.grad / batch_size

3.4 模型训练

# 设置超参数
lr = 0.03
num_epochs =5
net = linreg
loss = square_loss
batch_size = 10
for epoch in range(num_epochs):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):l = loss(net(X,w,b),y).sum()l.backward()sgd([w,b],lr,batch_size=batch_size)#梯度清零避免梯度累加w.grad.data.zero_()b.grad.data.zero_()train_l = loss(net(features,w,b),labels)print('epoch %d, loss %f' %(epoch +1 ,train_l.mean().item()))

epoch 1, loss 0.032550
epoch 2, loss 0.000133
epoch 3, loss 0.000053
epoch 4, loss 0.000053
epoch 5, loss 0.000053


基于pytorch的线性模型的实现

  1. 相关数据和初始化与上面构建相同
  2. 定义模型
import torch
from torch import nn
class LinearNet(nn.Module):def __init__(self, n_feature):# 调用父类的初始化super(LinearNet,self).__init__()# Linear(输入特征数,输出特征的数量,是否含有偏置项)self.linera = nn.Linear(n_feature,1)def forward(self,x):y = self.linera(x)return y
#打印模型的结构:
net = LinearNet(num_input)
print(net) 
# LinearNet( (linera): Linear(in_features=2, out_features=1, bias=True)
)
  1. 初始化模型的参数
from torch.nn import init
init.normal_(net.linera.weight,mean=0,std= 0.1)
init.constant_(net.linera.bias ,val=0)
  1. 定义损失函数
loss = nn.MSELoss()

5.定义优化算法

import torch.optim as optim
optimizer =  optim.SGD(net.parameters(),lr = 0.03)
print(optimizer)
  1. 训练模型:
num_epochs = 3
for epoch in range(1,num_epochs+1):for X,y in data_iter(batch_size= batch_size,features=features,labels= labels):output= net(X)l = loss(output,y.view(-1,1))optimizer.zero_grad()l.backward()optimizer.step()print('epoch %d ,loss: %f' %(epoch,l.item()) )

epoch 1 ,loss: 0.000159
epoch 2 ,loss: 0.000089
epoch 3 ,loss: 0.000066

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/430756.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第十四届蓝桥杯嵌入式国赛

一. 前言 本篇博客主要讲述十四届蓝桥杯嵌入式的国赛题目,包括STM32CubeMx的相关配置以及相关功能实现代码以及我在做题过程中所遇到的一些问题和总结收获。如果有兴趣的伙伴还可以去做做其它届的真题,可去 蓝桥云课 上搜索历届真题即可。 二. 题目概述 …

论文阅读与分析:Few-Shot Graph Learning for Molecular Property Prediction

论文阅读与分析:Few-Shot Graph Learning for Molecular Property Prediction 论文地址和代码地址1 摘要2 主要贡献3 基础知识Meta Learning1 介绍2 学习算法Step 1: What is learnable in a learning algorithm?Step 2:Define loss function for learn…

基于C语言开发(控制台)通讯录管理程序

通讯录程序设计 一、课程设计题目与要求 题目 :通讯录管理程序 1. 问题描述 编写一个简单的通讯录管理程序。通讯录记录有姓名,地址(省、市(县)、街道),电话号码,邮政编码等四项。2. 基本要求 程序应提供的基本基本管理功能有…

众数信科AI智能体政务服务解决方案——寻知智能笔录系统

政务服务解决方案 寻知智能笔录方案 融合民警口供录入与笔录生成需求 2分钟内生成笔录并提醒错漏 助办案人员二次询问 提升笔录质量和效率 寻知智能笔录系统 众数信科AI智能体 产品亮点 分析、理解行业知识和校验规则 AI实时提醒用户文书需注意部分 全文校验格式、内…

领域驱动DDD三种架构-分层架构、洋葱架构、六边形架构

博主介绍: 大家好,我是Yuperman,互联网宇宙厂经验,17年医疗健康行业的码拉松奔跑者,曾担任技术专家、架构师、研发总监负责和主导多个应用架构。 技术范围: 目前专注java体系,以及golang、.Net、…

(1999-2018年)全国各城市-财政收入–营业税

涵盖了1999年至2018年间,全国各城市的财政收入中营业税的部分。数据来源于中国区域统计年鉴及各省市统计年鉴 1999-2018年全国各城市-财政收入-营业税资源-CSDN文库https://download.csdn.net/download/2401_84585615/89504622 不同行业对营业税的贡献也存在差异。…

电动车车牌识别系统源码分享

电动车车牌识别检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer V…

Apache CVE-2021-41773 漏洞复现

1.打开环境 docker pull blueteamsteve/cve-2021-41773:no-cgid docker run -d -p 8080:80 97308de4753d 2.访问靶场 3.使用poc curl http://47.121.191.208:8080/cgi-bin/.%2e/.%2e/.%2e/.%2e/etc/passwd 4.工具验证

智能新突破:AIOT 边缘计算网关让老旧水电表图像识别

数字化高速发展的时代,AIOT(人工智能物联网)技术正以惊人的速度改变着我们的生活和工作方式。而其中,AIOT 边缘计算网关凭借其强大的功能,成为了推动物联网发展的关键力量。 这款边缘计算网关拥有令人瞩目的 1T POS 算…

自驾游拼团系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,发布起人管理,景点信息管理,景点分类管理,拼团旅游管理,参团信息管理,拼团订单管理,系统管理 微信端账号功…

11. DPO 微调示例:根据人类偏好优化LLM大语言模型

在部署大模型之后,我们必然要和微调打交道。现在大模型的微调有非常多的方法,过去的文章中提到的微调方法通常依赖于问题和答案对,标注成本较高。 2023 年所提出的 Direct Preference Optimization(DPO)为我们提供了一…

C语言----指针

基本知识点:指针的定义、指针运算符和指针运算等基本概念。重 点:字符指针、指针数组和多级指针。难 点:利用指针类型解决复杂的应用问题。 指针的概念 要点归纳 1.指针变量 在计算机中,所有数据都通过变量存放在内存中,每个变量都…

【matlab】将程序打包为exe文件(matlab r2023a为例)

文章目录 一、安装运行时环境1.1 安装1.2 简介 二、打包三、打包文件为什么很大 一、安装运行时环境 使用 Application Compiler 来将程序打包为exe,相当于你使用C编译器把C语言编译成可执行程序。 在matlab菜单栏–App下面可以看到Application Compiler。 或者在…

啤酒过滤——关于过滤助剂的介绍

在啤酒的酿造过程中,过滤是一个关键步骤,在啤酒厂中最常用的过滤助剂主要有两种:硅藻土和珍珠岩。它们能够帮助去除杂质,确保啤酒的清澈和口感。过滤助剂通常以粉状形式存在,它们被涂抹在过滤机的支撑材料上&#xff0…

深度合成算法备案和大模型备案的区别是什么

以下是关于大语言模型上线备案和深度合成算法备案区别的文档内容: 一、大语言模型上线备案与深度合成算法备案的区别 备案对象 大语言模型上线备案:主要针对生成式人工智能(AIGC)产品中的大型语言模型,能够生成文本、图…

MT6765/MT6762(R/D/M)/MT6761(MT8766)安卓核心板参数比较_MTK联发科4G智能模块

联发科Helio P35 MT6765安卓核心板 MediaTek Helio P35 MT6765是智能手机的主流ARM SoC,于2018年末推出。它在两个集群中集成了8个ARM Cortex-A53内核(big.LITTLE)。四个性能内核的频率高达2.3GHz。集成显卡为PowerVR GE8320,频率…

MATLAB系列09:图形句柄

MATLAB系列09:图形句柄 9. 图形句柄9.1 MATLAB图形系统9.2 对象句柄9.3 对象属性的检测和更改9.3.1 在创建对象时改变对象的属性9.3.2 对象创建后改变对象的属性 9.4 用 set 函数列出可能属性值9.5 自定义数据9.6 对象查找9.7 用鼠标选择对象9.8 位置和单位9.8.1 图…

Leetcode面试经典150题-39.组合总数进阶:40.组合总和II

本题是扩展题,真实考过,看这个题之前先看一下39题 Leetcode面试经典150题-39.组合总数-CSDN博客 给定一个候选人编号的集合 candidates 和一个目标数 target ,找出 candidates 中所有可以使数字和为 target 的组合。 candidates 中的每个数…

E2VPT: An Effective and Efficient Approach for Visual Prompt Tuning

论文汇总 存在的问题 1.以前的提示微调方法那样只关注修改输入,而应该明确地研究在微调过程中改进自注意机制的潜力,并探索参数效率的极限。 2.探索参数效率的极值来减少可调参数的数量? 解决办法 提示嵌入进行transformer中 提示剪枝 Token-wise …

004_动手实现MLP(pytorch)

import torch from torch import nn from torch.nn import init import numpy as np import sys import d2lzh_pytorch as d2l # 1.数据预处理 mnist_train torchvision.datasets.FashionMNIST(root/Users/w/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist, t…