使用Pytorch实现linear_regression

使用Pytorch实现线性回归

# import necessary packages
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Set necessary Hyper-parameters.
input_size = 1
output_size = 1
num_epochs = 60
learning_rate = 0.001
# Define a Toy dataset.
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], [9.779], [6.182], [7.59], [2.167], [7.042], [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], [3.366], [2.596], [2.53], [1.221], [2.827], [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)# Confirm the data shape.
print(x_train.shape, y_train.shape)
(15, 1) (15, 1)
# Linear regression model
model = nn.Linear(input_size, output_size)
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, )
# Train the model
for epoch in range(num_epochs):# Convert numpy arrays to torch tensorsinputs = torch.from_numpy(x_train)targets = torch.from_numpy(y_train)# Forward passoutputs = model(inputs)loss = criterion(outputs, targets)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()# Set an output counterif (epoch+1) % 5 == 0:print('Epoch [{}/{}], loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))# Plot the graph
predicted = model(torch.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original data')
plt.plot(x_train, predicted, label='Fitted line')
plt.legend()
plt.show()
Epoch [5/60], loss: 7.1598
Epoch [10/60], loss: 3.0717
Epoch [15/60], loss: 1.4154
Epoch [20/60], loss: 0.7443
Epoch [25/60], loss: 0.4722
Epoch [30/60], loss: 0.3618
Epoch [35/60], loss: 0.3169
Epoch [40/60], loss: 0.2985
Epoch [45/60], loss: 0.2909
Epoch [50/60], loss: 0.2876
Epoch [55/60], loss: 0.2861
Epoch [60/60], loss: 0.2853

在这里插入图片描述

# Save the model checkpoint
torch.save(model.state_dict(), 'model_param.ckpt')
torch.save(model, 'model.ckpt')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/201029.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于单片机仓库温湿度监测报警系统仿真设计

**单片机设计介绍,基于单片机仓库温湿度监测报警系统仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的仓库温湿度监测报警系统可以被设计成能够实时监测仓库内的温度和湿度,并根据预设…

【C++】继承(上) 继承 | 子类的默认成员函数

一、继承 概念 继承(inheritance)是一种面向对象编程的概念,它允许一个类(称为子类或派生类)继承另一个类(称为父类或基类)的特征和行为。子类可以获得父类的成员函数和变量,而不需要重新编写它们。子类还…

这是一棵适合搜索二叉树

🎈个人主页:🎈 :✨✨✨初阶牛✨✨✨ 🐻强烈推荐优质专栏: 🍔🍟🌯C的世界(持续更新中) 🐻推荐专栏1: 🍔🍟🌯C语言初阶 🐻推荐专栏2: 🍔…

宏集新闻 | 虹科传感器事业部正式更名为宏集科技

致一直支持“虹科传感器”的朋友们: 为进一步整合资源,给您带来更全面、更优质的服务,我们非常荣幸地宣布,虹科传感器事业部已正式更名为宏集科技。这一重要的改变代表了虹科持续发展进程中的新里程碑,也体现了我们在传…

安装gitlab

安装gitlab 环境 关闭防火墙以及selinux,起码4核8G 内存至少 3G 不然启动不了 下载环境 gitlab官网:GitLab下载安装_GitLab最新中文基础版下载安装-极狐GitLab rpm包下载地址: [Yum - Nexus Repository Manager (gitlab.cn)](https://pack…

Android studio run 手机或者模拟器安装失败,但是生成了debug.apk

错误信息如下:Error Installation did not succeed. The application could not be installed:List of apks 出现中文乱码; 我首先尝试了打包,能正常安装,再次尝试了debug的安装包,也正常安装&#xff1…

报错!Jupyter notebook 500 : Internal Server Error

Jupyter notebook 报错 500 : Internal Server Error 问题背景 tensorflow-gpu环境,为跑特定代码专门开了一个环境,使用conda安装了Jupyter notebook,能够在浏览器打开Jupyter notebook,但是notebook打开ipynb会报错。 问题分析…

抖音seo矩阵系统源代码部署及产品功能设计分析

一、引言 随着抖音等短视频平台的崛起,越来越多的企业和个人开始关注如何在这些平台上提升曝光量和用户流量。抖音SEO(搜索引擎优化)是一种有效的方法,通过优化短视频内容和关键词,让更多的人找到并点击你的视频。本文…

【css】Google第三方登录按钮样式修改

文章目录 场景前置准备修改样式官方属性修改样式CSS修改样式按钮的高度height和border-radiusLogo和文字布局 场景 需要用到谷歌的第三方登录,登录按钮有自己的样式。根据官方文档:概览 | Authentication | Google for Developers,提供两种第…

【Web】Ctfshow Nodejs刷题记录

目录 ①web334 ②web335 ③web336 ④web337 ⑤web338 ⑥web339 ⑦web340 ⑧web341 ⑨web342-343 ⑩web344 ①web334 进来是一个登录界面 下载附件,简单代码审计 表单传ctfshow 123456即可 ②web335 进来提示 get上传eval参数执行nodejs代码 payload: …

基于安卓android微信小程序的个人管理小程序

运行环境 开发语言:Java 框架:ssm JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7(一定要5.7版本) 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包&a…

B站短视频如何去水印?一键解析下载B站视频!

在浏览B站视频时,我们有时会遇到带有水印的场景。这些水印可能会干扰我们对视频内容的观看体验,特别是在全屏观看时。此外,当我们想要保存或分享这些视频时,水印也会成为一种障碍。因此,去除水印的需求就变得非常迫切。…

035、目标检测-物体和数据集

之——物体检测和数据集 目录 之——物体检测和数据集 杂谈 正文 1.目标检测 2.目标检测数据集 3.目标检测和边界框 4.目标检测数据集示例 杂谈 目标检测是计算机视觉中应用最为广泛的,之前所研究的图片分类等都需要基于目标检测完成。 在图像分类任务中&am…

wsl安装ubuntu的问题点、处理及连接

WSL安装Ubuntu的参考链接 (41条消息) wsl报错:WslRegisterDistribution failed with error: 0x800701bc_yzpyzp的博客-CSDN博客_0x800701bc wsl (41条消息) 使用Ubuntu安装软件出现Unable to locate package错误解决办法_大灰狼学编程的博客-CSDN博客 手把手教你…

栈的生长方向不总是向下

据我了解,栈的生长方向向下,内存地址由高到低 测试 windows下: 符合上述情况 测试Linux下: 由此可见,栈在不同操作系统环境下,生长方向不总是向下

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测

时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测 目录 时序预测 | MATLAB实现基于LSTM-AdaBoost长短期记忆网络结合AdaBoost时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 x 基本介绍 1.Matlab实现LSTM-Adaboost时间序列预测…

Android HAL学习 及 与BSP的区别

Android HAL学习 及 与BSP的区别 参考链接: 1、https://www.cnblogs.com/looner/articles/11579335.html 2、https://blog.csdn.net/leesan0802/article/details/124087630 3、https://zhuanlan.zhihu.com/p/336531442 在HAL的学习之前,我们来先了解…

京东数据分析(京东数据采集):2023年10月京东平板电视行业品牌销售排行榜

鲸参谋监测的京东平台10月份平板电视市场销售数据已出炉! 根据鲸参谋电商数据分析平台的相关数据显示,10月份,京东平台上平板电视的销量将近77万,环比增长约23%,同比则下降约30%;销售额为21亿,环…

数据库数据恢复—MongoDB数据库文件拷贝出现错误的数据恢复案例

MongoDB数据库数据恢复环境: 一台Windows Server操作系统的虚拟机,虚拟机上部署有MongoDB数据库。 MongoDB数据库故障&检测: 在未关闭MongoDB服务的情况下,工作人员将MongoDB数据库文件拷贝到其他分区,然后将原数…

双12电视盒子推荐:测评员解析目前电视盒子哪个最好

电视盒子不需要每月缴费,只需联网就可以收看海量视频资源,游戏、网课、投屏等功能让电视盒子的使用场景更丰富,我每年都会进行数十次电视盒子测评,本期要分享的是双十二电视盒子推荐,全面解析目前电视盒子哪个最好。 一…