机器学习---使用 TensorFlow 构建神经网络模型预测波士顿房价和鸢尾花数据集分类

1. 预测波士顿房价

1.1 导包
from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport itertoolsimport pandas as pd
import tensorflow as tftf.logging.set_verbosity(tf.logging.INFO)

最后一行设置了TensorFlow日志的详细程度:

tf.logging.DEBUG:最详细的日志级别,用于记录调试信息。

tf.logging.INFO:用于记录一般的信息性消息,比如训练过程中的指标和进度。

tf.logging.WARN:用于记录警告消息,表示可能存在潜在问题,但不会导致程序终止。

tf.logging.ERROR:仅记录错误消息,表示程序遇到了错误并可能终止执行。

tf.logging.FATAL:记录严重错误消息,并终止程序的执行。

1.2 处理数据集
COLUMNS = ["crim", "zn", "indus", "nox", "rm", "age","dis", "tax", "ptratio", "medv"]
FEATURES = ["crim", "zn", "indus", "nox", "rm","age", "dis", "tax", "ptratio"]
LABEL = "medv"training_set = pd.read_csv("boston_train.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)
test_set = pd.read_csv("boston_test.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)
prediction_set = pd.read_csv("boston_predict.csv", skipinitialspace=True,skiprows=1, names=COLUMNS)

定义了一些列名和特征,并使用pd.read_csv函数读取了训练集、测试集和预测集的数据。

pd.read_csv函数来读取CSV文件,并将其转换为Pandas数据帧。

1.3  创建DNNRegressor对象
feature_cols = [tf.feature_column.numeric_column(k) for k in FEATURES]
regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols,hidden_units=[50,50,50],model_dir="./boston_model")

tf.feature_column.numeric_column函数用于创建一个表示数值特征的特征列。在这种情况下,它

会遍历FEATURES列表中的每个特征名称,并为每个特征创建一个数值特征列。

创建DNNRegressor对象的参数:

  feature_columns:这是包含特征列的列表,用于定义输入的特征。在这里,您传递了之前创建

feature_cols,它包含了用于模型训练的数值特征列。

  hidden_units:这是一个整数列表,用于定义隐藏层的结构。在这个例子中,您定义了一个具

有3个隐藏层的DNN模型,每个隐藏层都有50个神经元。

model_dir:这是模型保存的目录路径。在这里,您指定了"./boston_model"作为模型保存的目录。

1.4 创建输入函数
def get_input_fn(data_set, num_epochs=None, shuffle=True):return tf.estimator.inputs.pandas_input_fn(x=pd.DataFrame({k: data_set[k].values for k in FEATURES}),y = pd.Series(data_set[LABEL].values),num_epochs=num_epochs,shuffle=shuffle)

该输入函数将Pandas数据帧作为输入,并将其转换为TensorFlow的输入格式。具体而言,它将特

征数据集(由FEATURES列表指定的列)转换为x,将标签数据(由LABEL指定的列)转换为y

1.5 训练评估预测
regressor.train(input_fn=get_input_fn(training_set), steps=5000)
ev = regressor.evaluate(input_fn=get_input_fn(test_set, num_epochs=1, shuffle=False))
loss_score = ev["loss"]
print("Loss: {0:f}".format(loss_score))
y = regressor.predict(input_fn=get_input_fn(prediction_set, num_epochs=1, shuffle=False))
# .predict() returns an iterator of dicts; convert to a list and print
# predictions
predictions = list(p["predictions"] for p in itertools.islice(y, 6))
print("Predictions: {}".format(str(predictions)))

steps参数指定了训练的迭代步数,即模型将对训练数据执行多少次梯度下降更新。

使用get_input_fn获取输入函数,该函数将测试集(test_set)作为输入数据。num_epochs参数设

置为1,表示测试集只会被迭代一次,shuffle参数被设置为False,表示测试集不需要进行洗牌。

然后提取评估结果中的损失值(loss),并将其赋值给loss_score变量。

通过迭代预测结果的字典形式,将预测值提取出来,并将其存储在predictions列表中。

2. 鸢尾花数据集分类

import tensorflow as tf
import pandas as pdCOLUMN_NAMES = ['SepalLength', 'SepalWidth','PetalLength', 'PetalWidth', 'Species']# Import training dataset
training_dataset = pd.read_csv('iris_training.csv', names=COLUMN_NAMES, header=0)
train_x = training_dataset.iloc[:, 0:4]
train_y = training_dataset.iloc[:, 4]# Import testing dataset
test_dataset = pd.read_csv('iris_test.csv', names=COLUMN_NAMES, header=0)
test_x = test_dataset.iloc[:, 0:4]
test_y = test_dataset.iloc[:, 4]# Setup feature columns
columns_feat = [tf.feature_column.numeric_column(key='SepalLength'),tf.feature_column.numeric_column(key='SepalWidth'),tf.feature_column.numeric_column(key='PetalLength'),tf.feature_column.numeric_column(key='PetalWidth')
]# Build Neural Network - Classifier
classifier = tf.estimator.DNNClassifier(feature_columns=columns_feat,# Two hidden layers of 10 nodes each.hidden_units=[10, 10],# The model is classifying 3 classesn_classes=3)# Define train function
def train_function(inputs, outputs, batch_size):dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), outputs))dataset = dataset.shuffle(1000).repeat().batch(batch_size)return dataset.make_one_shot_iterator().get_next()# Train the Model.
classifier.train(input_fn=lambda:train_function(train_x, train_y, 100),steps=1000)# Define evaluation function
def evaluation_function(attributes, classes, batch_size):attributes=dict(attributes)if classes is None:inputs = attributeselse:inputs = (attributes, classes)dataset = tf.data.Dataset.from_tensor_slices(inputs)assert batch_size is not None, "batch_size must not be None"dataset = dataset.batch(batch_size)return dataset.make_one_shot_iterator().get_next()# Evaluate the model.
eval_result = classifier.evaluate(input_fn=lambda:evaluation_function(test_x, test_y, 100))print('\nAccuracy: {accuracy:0.3f}\n'.format(**eval_result))

首先导入所需的库,包括 TensorFlow 和 Pandas。然后,定义了一个包含特征列的列

表 columns_feat,用于描述输入数据的特征。接下来,通过 Pandas 读取训练集和测试集的数

据,并将其分为输入特征和输出类别。

然后,使用 tf.estimator.DNNClassifier 类构建了一个多层感知机神经网络分类器。该分类器具

有两个隐藏层,每个隐藏层包含10个节点,输出层用于分类3个类别的鸢尾花。

然后,定义了一个训练函数 train_function 和一个评估函数 evaluation_function,用于转换输

入数据并创建 TensorFlow 数据集。训练函数将训练数据转换为 Dataset 对象,并进行随机化、重

复和分批处理。评估函数将测试数据转换为 Dataset 对象,并进行分批处理。

最后,通过调用 classifier.train 方法来训练模型,使用训练函数作为输入函数,并指定训练步

数。然后,通过调用 classifier.evaluate 方法来评估模型的性能,使用评估函数作为输入函数,

并指定评估时的批大小。评估结果包括准确率,并通过 print 函数进行输出。

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/173952.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue中使用xlsx插件导出多sheet excel实现方法

安装xlsx,一定要注意版本: npm i xlsx0.17.0 -S package.json: {"name": "hello-world","version": "0.1.0","private": true,"scripts": {"serve": "vue-c…

ESM蛋白质语言模型系列

模型总览 第一篇《Biological structure and function emerge from scaling unsupervised learning to 250 million protein sequences 》ESM-1b 第二篇《MSA Transformer》在ESM-1b的基础上作出改进,将模型的输入从单一蛋白质序列改为MSA矩阵,并在Tran…

RK3568-适配at24c04模块

将at24c04模块连接到开发板i2c2总线上 i2ctool查看i2c2总线上都有哪些设备 UU表示设备地址的从设备被驱动占用,卸载对应的驱动后,UU就会变成从设备地址。at24c04模块设备地址 0x50和0x51是at24c04模块i2c芯片的设备地址。这个从芯片手册上也可以得知。A0 A1 A2表示的是模块对…

简单而高效:使用PHP爬虫从网易音乐获取音频的方法

概述 网易音乐是一个流行的在线音乐平台,提供了海量的音乐资源和服务。如果你想从网易音乐下载音频文件,你可能会遇到一些困难,因为网易音乐对其音频资源进行了加密和防盗链的处理。本文将介绍一种使用PHP爬虫从网易音乐获取音频的方法&…

Go学习第十六章——Gin文件上传与下载

Go web框架——Gin文件上传与下载 1. 文件上传1.1 入门案例(单文件)1.2 服务端保存文件的几种方式SaveUploadedFileCreateCopy 1.3 读取上传的文件1.4 多文件上传 2. 文件下载2.1 快速入门2.2 前后端模式下的文件下载2.3 中文乱码问题 1. 文件上传 1.1 …

lesson2(补充)关于>>运算符和<<运算符重载

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 前言&#xff1a; cout和cin我们在使用时需要包含iostream头文件&#xff0c;我们可以知道的是cout是写在ostream类里的&#xff0c;cin是写在istream类里的&#xff0c;他们都是定义出的对象&#xff0c;而<< 和 >…

M1安装OpenPLC Editor

下载OpenPLC Editor for macOS.zip文件后&#xff0c;使用tar -zvxf命令解压&#xff0c;然后将"OpenPLC Editor"拖入到"应用程序"文件夹 右键点击"OpenPLC Editor"&#xff0c;打开这个""文件&#xff0c;替换为以下内容 #!/bin/bash…

香港服务器如何做负载均衡?

​  在现代互联网时代&#xff0c;随着网站访问量的不断增加&#xff0c;服务器的负载也越来越重。为了提高网站的性能和可用性&#xff0c;负载均衡成为了一种常见的解决方案。 什么是负载均衡? 负载均衡是一种技术解决方案&#xff0c;用于在多个服务器之间分配负载&#…

搜维尔科技:【应用】配备MTi-3的轻便型ROV,在水下进行地理标记视觉检测

部署潜水员进行水下摄像&#xff0c;不仅难度高而且费用昂贵&#xff0c;需要受过潜水和摄像两方面培训的专业人员来进行。但有些水下作业任务例如拍摄海底管道内部的照片&#xff0c;由于人员无法进入或危险度高的原因&#xff0c;无法由潜水员完成。 如今&#xff0c;俄罗…

vue源码分析(五)——vue render 函数的使用

文章目录 前言一、render函数1、render函数是什么&#xff1f; 二、render 源码分析1.执行initRender方法2.vm._c 和 vm.$createElement 调用 createElement 方法详解&#xff08;1&#xff09;区别&#xff08;2&#xff09;代码 3、原型上的_render方法&#xff08;1&#xf…

JWT详解解读读

&#x1f4d1;前言 本文主要是jwt解读文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是青衿&#x1f947; ☁️博客首页&#xff1a;CSDN主页放风讲故事 &#x1f304;每日一句&#xff1a;努力一点&#…

数据结构-初识泛型

写在前&#xff1a; 这一篇博客主要来初步的记录以下泛型的相关内容&#xff0c;内容比较琐碎&#xff0c;就不进行目录的整合&#xff0c;后续可能会对泛型这里进行系统性的梳理&#xff0c;此篇博客主要是对泛型有一个简单的认识与理解&#xff0c;需要知晓的内容。 当我调用…

七层负载均衡 HAproxy

一、HAproxy 1、负载均衡类型&#xff1a; (1) 无负载均衡&#xff1a; 没有负载均衡&#xff0c;用户直接连接到 Web 服务器。当许多用户同时访问服务器时&#xff0c;可能无法连接。 (2) 四层负载均衡&#xff1a; 用户访问负载均衡器&#xff0c;负载均衡器将用户的请求…

听GPT 讲Rust源代码--library/std(8)

题图来自Why is Rust programming language so popular?[1] File: rust/library/std/src/sys/sgx/abi/reloc.rs 在Rust源代码中&#xff0c;sgx/abi/reloc.rs文件的作用是定义了针对Intel Software Guard Extensions (SGX)的重定位相关结构和函数。 该文件中的Rela 结构定义了…

java之输入与输出的详细介绍

文章目录 输出的相关格式使用 Scanner 类进行控制台输入步骤&#xff1a;示例&#xff1a; 如何格式化输出&#xff1f;1. 使用 System.out.printf2. 使用 String.format printf与println 的区别printfprintln主要区别&#xff1a; 输出的相关格式 控制台输入是指通过命令行或…

C++ Qt/VTK装配体组成联动连接杆

效果 关键代码 #include "View3D.h" #include "Axis.h"#include <vtkActor.h> #include <vtkAppendPolyData.h > #include <vtkAreaPicker.h> #include <vtkAxesActor.h> #include <vtkBox.h> #include <vtkCamera.h>…

开源3D激光(视觉)SLAM算法汇总(持续更新)

原文连接 目录 一、Cartographer 二、hdl_graph_slam 三、LOAM 四、LeGO-LOAM 五、LIO-SAM 六、S-LOAM 七、M-LOAM 八、livox-loam 九、Livox-Mapping 十、LIO-Livox 十一、FAST-LIO2 十二、LVI-SAM 十三、FAST-Livo 十四、R3LIVE 十五、ImMesh 十六、Point-LIO 一、Cartograph…

辅助驾驶功能开发-功能规范篇(22)-3-L2级辅助驾驶方案功能规范

1.3.3 TLA系统功能定义 1.3.3.1 状态机 1.3.3.2 状态迁移图 1.3.3.3 功能定义 1.3.3.3.1 信号需求列表 1.3.3.3.2 系统开启关闭 1)初始化 车辆上电后,交通灯辅助系统(TLA)进行初始化,控制器需在 220ms 内发出第一帧报文,并在 3s 内完成内部自检,同时上电 3s 内不进行…

VR数字党建:红色文化展厅和爱国主义教育线上线下联动

伴随着党建思想的加深&#xff0c;很多政府单位都有打造VR党建展厅的想法&#xff0c;而党建基地也是激发爱国热情、凝聚人民力量、培养民族精神的重要场所。现如今&#xff0c;伴随着5G、VR等技术的成熟&#xff0c;VR数字党建积极推动运用VR技术&#xff0c;推动红色文化展厅…

PyTorch中grid_sample的使用方法

官方文档首先Pytorch中grid_sample函数的接口声明如下&#xff1a; torch.nn.functional.grid_sample(input, grid, modebilinear, padding_modezeros, align_cornersNone)input : 输入tensor&#xff0c; shape为 [N, C, H_in, W_in]grid: 一个field flow&#xff0c; shape为…