【机器学习】梯度下降预测波士顿房价

文章目录

  • 前言
  • 一、数据集介绍
  • 二、预测房价代码
    • 1.引入库
    • 2.数据
    • 3.梯度下降
  • 总结


前言

梯度下降算法学习。

一、数据集介绍

波士顿房价数据集:波士顿房价数据集,用于线性回归预测

二、预测房价代码

1.引入库

from sklearn.linear_model import LinearRegression as LR
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston as boston 
import matplotlib.pyplot as plt
from sklearn.utils import shuffle
import numpy as np
from sklearn.metrics import mean_squared_error

2.数据

def preprocess():# get the dataset of bostonX = boston().datay = boston().targetname_data = boston().feature_names# draw the figure of relationship between feature and priceplt.figure(figsize=(20,20))for i in range(len(X[0])):plt.subplot(5, 3, i + 1)plt.scatter(X[:, i], y, s=20)plt.title(name_data[i])plt.show()# 删除相关性较低的特征# X = np.delete(X, [0, 1, 3, 4, 6, 7, 8, 9, 11], axis=1)# normalizationfor i in range(len(X[0])):X[:, i] = (X[:, i] - X[:, i].min()) / (X[:, i].max() - X[:, i].min())# split into test and trainXtrain, Xtest, Ytrain, Ytest = train_test_split(X, y, test_size=0.3, random_state=10)return Xtrain, Xtest, Ytrain, Ytest, X
def lr(Xtrain, Xtest, Ytrain, Ytest, if_figure):# use LinearRegressionreg = LR().fit(Xtrain, Ytrain)y_pred = reg.predict(Xtest)loss = mean_squared_error(Ytest, y_pred)print("*************LR*****************")print("w\t= {}".format(reg.coef_))print("b\t= {:.4f}".format(reg.intercept_))# draw the figure of predict resultsif if_figure:plt.figure(figsize = (14,6),dpi = 80)plt.plot(range(len(Ytest)), Ytest, c="blue", label="real")plt.plot(range(len(y_pred)), y_pred, c="red", linestyle=':', label="predict")plt.title("predict results from row LR")plt.legend()plt.show()return loss

3.梯度下降

def gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, rate):# 梯度下降def grad(y, yp, X):grad_w = (y - yp) * (-X)grad_b = (y - yp) * (-1)return [grad_w, grad_b]# 设置训练参数epoch_train = 100learning_rate = ratew = np.random.normal(0.0, 1.0, (1, len(X[0])))b = 0.0   loss_train = []loss_test = []for epoch in range(epoch_train + 1):loss1 = 0for i in range(len(Xtrain)):yp = w.dot(Xtrain[i]) + b# 计算损失err = Ytrain[i] - yploss1 += err ** 2# 迭代更新 w 和 bgw = grad(Ytrain[i], yp, Xtrain[i])[0]gb = grad(Ytrain[i], yp, Xtrain[i])[1]w = w - learning_rate * gwb = b - learning_rate * gb# 记录损失loss_train.append(loss1 / len(Xtrain))loss11 = 0for i in range(len(Xtest)):yp2 = w.dot(Xtest[i]) + berr2 = Ytest[i] - yp2loss11 += err2 ** 2# 记录损失loss_test.append(loss11 / len(Xtest))# shuffle the dataXtrain, Ytrain = shuffle(Xtrain, Ytrain)# draw the figure of lossif if_figure:plt.figure()plt.title("figure of loss")plt.plot(range(len(loss_train)), loss_train, c="blue", linestyle=":", label="train")plt.plot(range(len(loss_test)), loss_test, c="red", label="test")plt.legend()plt.show()# draw figure of predict resultsif if_figure:Predict_value = []for i in range(len(Xtest)):Predict_value.append(w.dot(Xtest[i]) + b)plt.figure()plt.title("predict results from gradScent")plt.plot(range(len(Xtest)), Ytest, c="blue", label="real")plt.plot(range(len(Xtest)), Predict_value, c="red", linestyle=':', label="predict")plt.legend()plt.show()return loss_test[-1], w, b
def test():if_figure = TrueXtrain, Xtest, Ytrain, Ytest, X = preprocess()loss_lr = lr(Xtrain, Xtest, Ytrain, Ytest, if_figure)loss_gd, w, b = gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, 0.01)print("*************GD*****************")      print("w\t: {}".format(w))print("b\t: {}".format(b))print("************loss****************")print("lr\t: %.4f" % loss_lr)print("gd\t: %.4f" % loss_gd)
def searchRate():if_figure = FalseXtrain, Xtest, Ytrain, Ytest, X = preprocess()loss_grad = []w_grad = []b_grad = []rates = list(np.arange(0.001, 0.05, 0.001))epoch = 1for rate in rates:loss, w, b = gradDescnet(Xtrain, Xtest, Ytrain, Ytest, X, if_figure, rate)loss_grad.append(loss[0])w_grad.append(w)b_grad.append(b)print("epoch %d: %.4f" % (epoch, loss_grad[-1]))epoch += 1plt.figure()plt.plot(rates, loss_grad)plt.title("loss under different rate")plt.show()loss_grad_min = min(loss_grad)position = loss_grad.index(loss_grad_min)w = w_grad[position]b = b_grad[position]rate = rates[position]loss_lr = lr(Xtrain, Xtest, Ytrain, Ytest, if_figure)print("*************GD*****************")print("w\t: {}".format(w))print("b\t: {}".format(b))print("rate: %.3f" % rate)print("************loss****************")print("lr\t: %.4f" % loss_lr)print("gd\t: %.4f" % loss_grad_min)

data = boston
Xtrain, Xtest, Ytrain, Ytest, X = preprocess()

在这里插入图片描述

lr(Xtrain, Xtest, Ytrain, Ytest,True)

在这里插入图片描述

test()

在这里插入图片描述在这里插入图片描述

searchRate()

在这里插入图片描述
在这里插入图片描述

总结

通过此次学习,对梯度下降算法有了更深的认识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/186372.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python爬虫实战-批量爬取美女图片网下载图片

大家好,我是python222小锋老师。 近日锋哥又卷了一波Python实战课程-批量爬取美女图片网下载图片,主要是巩固下Python爬虫基础 视频版教程: Python爬虫实战-批量爬取美女图片网下载图片 视频教程_哔哩哔哩_bilibiliPython爬虫实战-批量爬取…

【Java】I/O流—缓冲流的基础入门和文件拷贝的实战应用

🌺个人主页:Dawn黎明开始 🎀系列专栏:Java ⭐每日一句:你能坚持到什么程度,决定你能达到什么高度 📢欢迎大家关注🔍点赞👍收藏⭐️留言📝 文章目录 一.&…

RapidSSL证书

RapidSSL是一家经验丰富的证书颁发机构,主要专注于提供标准和通配符SSL证书的域验证SSL证书。在2017年被DigicertCA收购后,RapidSSL改进了技术并开始使用现代基础设施。专注于为小型企业和网站提供基本安全解决方案的SSL加密。RapidSSL它具有强大的浏览器…

ZYNQ_project:key_led

条件里是十进制可以不加进制说明,编译器默认是10进制,其他进制要说明。 实验目标: 模块框图: 时序图: 代码: include "para.v"module key_filter (input wire …

python3.8.10虚拟环境安装talib总报平台不匹配

目录 环境: 需求: 问题: 概述 过程及解决 解决方案总结 环境: 操作系统:window10、64位 开发工具:pycharm python版本:python3.8.10 需求: 在python3.8.10的虚拟环境中安…

短短 45 分钟发布会,OpenAI 如何再次让 AI 圈一夜未眠

目录 前言 1. GPT-4 Turbo,更快,更省钱 2. GPT Store 来了! 3. 零代码创建 AI Agent 前言 对于 AI 行业从业者来说,刚刚可能是一夜未眠。 北京时间 11 月 7 日凌晨,美国人工智能公司 OpenAI 的开发者大会正式开…

HTTParty库数据抓取代码示例

使用HTTParty库的网络爬虫程序, ruby require httparty # 设置服务器 proxy_host proxy_port # 使用HTTParty库发送HTTP请求获取网页内容 response HTTParty.get(/, :proxy > { :host > proxy_host, :port > proxy_port }) # 打印获取的网页内容 …

亚马逊云科技产品测评』活动征文|通过使用Amazon Neptune来预测电影类型初体验

文章目录 福利来袭Amazon Neptune什么是图数据库为什么要使用图数据库什么是Amazon NeptuneNeptune 的特点 快速入门环境搭建notebook 图神经网络快速构建加载数据配置端点Gremlin 查询清理 删除环境S3 存储桶删除 授权声明:本篇文章授权活动官方亚马逊云科技文章转…

Sketch是什么软件,如何收费和获得免费版

Sketch软件为设计师构建了一个优秀的本地Mac应用程序。Sketch是整个设计过程的平台,通过基于Web的工具共享工作,获取反馈,测试原型,并将其移交给任何浏览器。Sketch软件的定价根据不同的许可类型和订阅计划而变化。本文从Sketch软…

LeetCode算法题解(回溯、难点)|LeetCode332. 重新安排行程

LeetCode332. 重新安排行程 题目链接:332. 重新安排行程 题目描述: 给你一份航线列表 tickets ,其中 tickets[i] [fromi, toi] 表示飞机出发和降落的机场地点。请你对该行程进行重新规划排序。 所有这些机票都属于一个从 JFK&#xff08…

密码学 - RSA签名算法

实验九 RSA签名算法- 一、实验目的 通过实验掌握GMP开源软件的用法,理解RSA数字签名算法,学会RSA数字签名算法程序设计,提高一般数字签名算法的设计能力。 二、实验要求 (1)基于GMP开源软件,实现RSA签名算法。 (2)要求有对应…

浅谈多回路电表在荷兰光伏系统配电项目中的应用

1.背景信息 Background: 随着全球化石能源(石油,煤炭)越来越接近枯竭,污染日趋严重,气候日益变暖等问题,全球多个国家和地区相继出台了法规政策,推动了光伏产业的发展。但是现有的光…

MySQL索引的数据结构

1. 索引及其优缺点 1.1 索引概述 MySQL官方对索引的定义为:索引(Index)是帮助MySQL高效获取数据的数据结构。 索引的本质:索引是数据结构。你可以简单理解为“排好序的快速查找数据结构”,满足特定查找算法。这些数据结…

合成数据在医疗保健行业的案例研究

从机器人辅助手术到医学成像技术,人工智能在医疗保健领域的应用正在迅速改变医疗保健行业,并改善服务成本和服务质量。例如,埃森哲表示,到 150 年,人工智能临床健康应用每年可以为美国医疗保健行业节省 2026 亿美元。 …

Spring RabbitMQ那些事(1-交换机配置消息发送订阅实操)

这里写目录标题 一、序言二、配置文件application.yml三、RabbitMQ交换机和队列配置1、定义4个队列2、定义Fanout交换机和队列绑定关系2、定义Direct交换机和队列绑定关系3、定义Topic交换机和队列绑定关系4、定义Header交换机和队列绑定关系 四、RabbitMQ消费者配置五、Rabbit…

各大电商平台关于预制菜品种酸菜鱼销售量

# 导入需要的包 library(rvest) # 用于网页抓取 library(tidyverse) # 用于数据处理 library(stringr) # 用于字符串处理# 设置代理信息 proxy_host <- "www.duoip.cn" proxy_port <- 8000# 设置要爬取的网页 url <- "https://jshk.com.cn/products/sa…

【正点原子STM32连载】 第四十九章 SD卡实验 摘自【正点原子】APM32F407最小系统板使用指南

1&#xff09;实验平台&#xff1a;正点原子stm32f103战舰开发板V4 2&#xff09;平台购买地址&#xff1a;https://detail.tmall.com/item.htm?id609294757420 3&#xff09;全套实验源码手册视频下载地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html## 第四…

Spring的循环依赖问题

文章目录 1.什么是循环依赖2.代码演示3.分析问题4.问题解决5.Spring循环依赖6. 疑问点6.1 为什么需要三级缓存6.2 没有三级缓存能解决吗&#xff1f;6.3 三级缓存分别什么作用 1.什么是循环依赖 上图是循环依赖的三种情况&#xff0c;虽然方式有点不一样&#xff0c;但是循环依…

Yolov8模型训练报错:torch.cuda.OutOfMemoryError

最近在使用自己的数据训练Yolov8模型的时候遇到了很多错误&#xff0c;下面将逐一解答。 问题报错 在训练过程中红字报错&#xff1a;torch.cuda.OutOfMemoryError: CUDA out of memory. 后面还会跟着一大段报错&#xff1a; Tried to allocate XXX MiB (GPU 0; XXX GiB to…

【云原生】使用nginx反向代理后台多服务器

背景 随着业务发展&#xff0c; 用户访问量激增&#xff0c;单台服务器已经无法满足现有的访问压力&#xff0c;研究后需要将后台服务从原来的单台升级为多台服务器&#xff0c;那么原来的访问方式无法满足&#xff0c;所以引入nginx来代理多台服务器&#xff0c;统一请求入口…