【深度学习实验】线性模型(三):使用Pytorch实现简单线性模型:搭建、构造损失函数、计算损失值

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入库

1. 定义线性模型linear_model

2. 定义损失函数loss_function

3. 定义数据

4. 调用模型

5. 完整代码


一、实验介绍

  • 使用Pytorch实现
    • 线性模型搭建
    • 构造损失函数
    • 计算损失值

 二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        线性模型是一种基本的机器学习模型,用于建立输入特征与输出之间的线性关系。它是一种线性组合模型,通过对输入特征进行加权求和,再加上一个偏置项,来预测输出值。

        线性模型的一般形式可以表示为:y = w1x1 + w2x2 + ... + wnxn + b,其中y是输出变量,x1, x2, ..., xn是输入特征,w1, w2, ..., wn是特征的权重,b是偏置项。模型的目标是通过调整权重和偏置项,使预测值与真实值之间的差异最小化。

线性模型有几种常见的应用形式:

  1. 线性回归(Linear Regression):用于建立输入特征与连续输出之间的线性关系。它通过最小化预测值与真实值的平方差来拟合最佳的回归直线。

  2. 逻辑回归(Logistic Regression):用于建立输入特征与二分类或多分类输出之间的线性关系。它通过使用逻辑函数(如sigmoid函数)将线性组合的结果映射到概率值,从而进行分类预测。

  3. 支持向量机(Support Vector Machines,SVM):用于二分类和多分类问题。SVM通过找到一个最优的超平面,将不同类别的样本分隔开。它可以使用不同的核函数来处理非线性问题。

  4. 岭回归(Ridge Regression)和Lasso回归(Lasso Regression):用于处理具有多重共线性(multicollinearity)的回归问题。它们通过对权重引入正则化项,可以减小特征的影响,提高模型的泛化能力。

        线性模型的优点包括简单、易于解释和计算效率高。它们在许多实际问题中都有广泛的应用。然而,线性模型也有一些限制,例如对非线性关系的建模能力较弱。在处理复杂的问题时,可以通过引入非线性特征转换或使用核函数进行扩展,以提高线性模型的性能。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

0. 导入库

import torch

1. 定义线性模型linear_model

        该函数接受输入数据x,使用随机生成的权重w和偏置b,计算输出值output。这里的线性模型的形式为 output = x * w + b

def linear_model(x):w = torch.rand(1, 1, requires_grad=True)b = torch.randn(1, requires_grad=True)return torch.matmul(x, w) + b

2. 定义损失函数loss_function

      这里使用的是均方误差(MSE)作为损失函数,计算预测值与真实值之间的差的平方。

def loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return loss

3. 定义数据

  • 生成一个随机的输入张量 x,形状为 (5, 1),表示有 5 个样本,每个样本的特征维度为 1。

  • 生成一个目标张量 y,形状为 (5, 1),表示对应的真实标签。

  • 打印数据的信息,包括每个样本的输入值x和目标值y
x = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):print("Item " + str(i), "x:", x[i][0], "y:", y[i])

4. 调用模型

  • 使用 linear_model 函数对输入 x 进行预测,得到预测结果 prediction

  • 使用 loss_function 计算预测结果与真实标签之间的损失,得到损失张量 loss

  • 打印了每个样本的损失值。
prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):print("Item ", str(i), "Loss:", loss[i])

5. 完整代码

import torchdef linear_model(x):w = torch.rand(1, 1, requires_grad=True)b = torch.randn(1, requires_grad=True)return torch.matmul(x, w) + bdef loss_function(y_true, y_pred):loss = (y_pred - y_true) ** 2return lossx = torch.rand(5, 1)
y = torch.tensor([1, -1, 1, -1, 1], dtype=torch.float32).view(-1, 1)
print("The data is as follows:")
for i in range(x.shape[0]):print("Item " + str(i), "x:", x[i][0], "y:", y[i])prediction = linear_model(x)
loss = loss_function(y, prediction)
print("The all loss value is:")
for i in range(len(loss)):print("Item ", str(i), "Loss:", loss[i])


注意:

        本实验的线性模型仅简单地使用随机权重和偏置,计算了模型在训练集上的均方误差损失,没有使用优化算法进行模型参数的更新。

        通常情况下会使用梯度下降等优化算法来最小化损失函数,并根据训练数据不断更新模型的参数,具体内容请听下回分解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/134962.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TensorFlow与pytorch特定版本虚拟环境的安装

TensorFlow与Python的版本对应,注意,一定要选择对应的版本,否则会让你非常痛苦,折腾很久搞不清楚原因。 建议使用国内镜像源安装 没有GPU后缀的就表示是CPU版本的,不加版本就是最新 pip install tensorflow -i https:…

Learn Prompt-人工智能基础

什么是人工智能?很多人能举出很多例子说这就是人工智能,但是让我们给它定义一个概念大家又觉得很难描述的清楚。实际上,人工智能并不是计算机科学领域专属的概念,在其他学科包括神经科学、心理学、哲学等也有人工智能的概念以及相…

Vue3+ElementUI使用

<!DOCTYPE html> <html> <head><meta charset"UTF-8"><meta name"viewport" content"initial-scale1.0,maximum-scale1.0,minimum-scale1.0,user-scalable0, widthdevice-width"/><!-- 引入样式 --><lin…

《C和指针》笔记24: 指针和间接访问

本文主要讲指针和间接访问&#xff0c;标题对应《C和指针对应的章节》&#xff0c;引用的地方是自己写的一些注释、理解和总结。 指针、间接访问和左值 先回顾一下左值和右值 左值代表着一个位置。右值代表着一个值。赋值等号左边是个左值&#xff0c;赋值等号右边是一个右值…

Vue入门简介(带你打开Vue的大门)

目录 前言 一、Vue简介 1. 什么是Vue 2. Vue的应用场景 3. Vue的作用&#xff08;重要性&#xff09; 4. 什么是MVVM模式 5. 开源库网址 二、Vue入门使用 1. 基础使用步骤 1.1 引入Vue.js 1.2 创建Vue实例 1.3 编写Vue模板 1.4 数据绑定与指令 1.5 调用Vue方法和…

flutter聊天界面-TextField输入框buildTextSpan实现@功能展示高亮功能

flutter聊天界面-TextField输入框buildTextSpan实现功能展示高亮功能 最近有位朋友讨论的时候&#xff0c;提到了输入框的高亮展示。在flutter TextField中需要插入特殊样式的标签&#xff0c;比如&#xff1a;“请 张三 回答一下”&#xff0c;这一串字符在TextField中输入&a…

群辉 Synology NAS Docker 安装 RustDesk-server 自建服务器只要一个容器

from https://blog.zhjh.top/archives/M8nBI5tjcxQe31DhiXqxy 简介 之前按照网上的教程&#xff0c;rustdesk-server 需要安装两个容器&#xff0c;最近想升级下版本&#xff0c;发现有一个新镜像 rustdesk-server-s6 可以只安装一个容器。 The S6-overlay acts as a supervi…

NSS [HNCTF 2022 WEEK2]ohmywordpress(CVE-2022-0760)

NSS [HNCTF 2022 WEEK2]ohmywordpress&#xff08;CVE-2022-0760&#xff09; 题目描述&#xff1a;flag在数据库里面。 开题&#xff1a; 顺着按钮一直点下去会发现出现一个按钮叫安装WordPress 安装完之后的界面&#xff0c;有一个搜索框。 F12看看network。 又出现了这个…

day22集合01

1.Collection集合 1.1数组和集合的区别【理解】 相同点 都是容器,可以存储多个数据 不同点 数组的长度是不可变的,集合的长度是可变的 数组可以存基本数据类型和引用数据类型 集合只能存引用数据类型,如果要存基本数据类型,需要存对应的包装类 1.2集合类体系结构【理解】…

【Axure原型素材】扫一扫

今天和粉丝们免费分享扫一扫的原型素材&#xff0c;"扫一扫"是一项常见的移动应用功能&#xff0c;通常通过手机或平板电脑上的摄像头来扫描二维码或条形码以实现各种功能。下面是和大家分享扫一扫的常用素材~~~ 【原型效果】 【Axure原型素材】扫一扫 【原型预览】…

【环境配置】基于Docker配置Chisel-Bootcamp环境

文章目录 Chisel是什么Chisel-Bootcamp是什么基于Docker配置Chisel-Bootcamp官网下载Docker安装包Docker换源启动Bootcamp镜像常用docker命令 可能产生的问题 Chisel是什么 Chisel是Scala语言的一个库&#xff0c;可以由Scala语言通过import引入。 Chisel编程可以生成Verilog代…

IDEA创建完Maven工程后,右下角一直显示正在下载Maven插件

原因&#xff1a; 这是由于新建的Maven工程&#xff0c;IDEA会用它内置的默认的Maven版本&#xff0c;使用国外的网站下载Maven所需的插件&#xff0c;速度很慢 。 解决方式&#xff1a; 每次创建 Project 后都需要设置 Maven 家目录位置&#xff08;就是我们自己下载的Mav…

嵌入式:驱动开发 Day2

作业&#xff1a;字符设备驱动&#xff0c;完成三盏LED灯的控制 驱动代码&#xff1a; mychrdev.c #include <linux/init.h> #include <linux/module.h> #include <linux/fs.h> #include <linux/uaccess.h> #include <linux/io.h> #include &q…

Windows/Linux(命令、安装包和源码安装)平台各个版本QT详细安装教程

前言 本文章主要介绍了Windows/Linux平台下&#xff0c;QT4&#xff0c;QT5&#xff0c;QT6的安装步骤。为什么要把QT版本分开介绍呢&#xff0c;因为这三个版本&#xff0c;安装步骤都不一样。Windows平台&#xff0c;QT4的Qt Creator&#xff0c;QT库和编译器是分开的&#…

Android Studio 报错问题记录

工具地址 由于之前手贱不知道点了一个什么东西更新&#xff0c;导致一个code1报错&#xff0c;后来又一通瞎比操作直接吧Android Studio弄得打不开模拟器了&#xff0c;所以我后面就全部卸载重新安装了一下&#xff0c;并把之前遇到的问题做下记录&#xff0c;可能并不适用于每…

笔记1.5:计算机网络体系结构

从功能上描述计算机网络结构 分层结构 每层遵循某个网络协议完成本层功能 基本概念 实体&#xff1a;表示任何可发送或接收信息的硬件或软件进程。 协议是控制两个对等实体进行通信的规则的集合&#xff0c;协议是水平的。 任一层实体需要使用下层服务&#xff0c;遵循本层…

自定义实现简易版ArrayList

文章目录 1.了解什么是顺序表2.实现哪些功能3.初始化ArrayList4.实现功能接口遍历顺序表判断顺序表是否已满添加元素指定下标添加元素自定义下标不合法异常判断顺序表是否为空查找指定元素是否存在查找指定元素返回下标获取指定下标的元素顺序表为空异常修改指定下标元素的值删…

Stable DIffusion 炫酷应用 | AI嵌入艺术字+光影光效

目录 1 生成AI艺术字基本流程 1.1 生成黑白图 1.2 启用ControlNet 参数设置 1.3 选择大模型 写提示词 2 不同效果组合 2.1 更改提示词 2.2 更改ControlNet 2.2.1 更改模型或者预处理器 2.2.2 更改参数 3. 其他应用 3.1 AI光影字 本节需要用到ControlNet&#xff0c;可…

手摸手系列之前端Vue实现PDF预览及打印的终极解决方案

前言 近期我正在开发一个前后端分离项目&#xff0c;使用了Spring Boot 和 Vue2&#xff0c;借助了国内优秀的框架 jeecg&#xff0c;前端UI库则选择了 ant-design-vue。在项目中&#xff0c;需要实现文件上传功能&#xff0c;同时还要能够在线预览和下载图片和PDF文件&#x…

[golang 流媒体在线直播系统] 4.真实RTMP推流摄像头把摄像头拍摄的信息发送到腾讯云流媒体服务器实现直播

用RTMP推流摄像头把摄像头拍摄的信息发送到腾讯云流媒体服务器实现直播,该功能适用范围广,比如:幼儿园直播、农场视频直播, 一.准备工作 要实现上面的功能,需要准备如下设备: 推流摄像机&#xff08;监控&#xff09; 流媒体直播服务器(腾讯云流媒体服务器,自己搭建的流媒体服务…