基于MLP完成CIFAR-10数据集和UCI wine数据集的分类

基于MLP完成CIFAR-10数据集和UCI wine数据集的分类,使用到了sklearn和tensorflow,并对图片分类进行了数据可视化展示

数据集介绍

UCI wine数据集:

http://archive.ics.uci.edu/dataset/109/wine

这些数据是对意大利同一地区种植的葡萄酒进行化学分析的结果,但来自三个不同的品种。该分析确定了三种葡萄酒中每一种中发现的13种成分的数量。

CIFAR-10数据集:

https://www.cs.toronto.edu/~kriz/cifar.html

CIFAR-10 数据集由 10 类 60000 张 32x32 彩色图像组成,每类 6000 张图像。有 50000 张训练图像和 10000 张测试图像。

数据集分为 5 个训练批次和 1 个测试批次,每个批次有 10000 张图像。测试批次正好包含从每个类中随机选择的 1000 张图像。训练批次以随机顺序包含剩余的图像,但某些训练批次可能包含来自一个类的图像多于另一个类。在它们之间,训练批次正好包含来自每个类的 5000 张图像

MLP算法

MLP 代表多层感知器(Multilayer Perceptron),是一种基本的前馈神经网络(Feedforward Neural Network)模型。它由一个输入层、一个或多个隐藏层和一个输出层组成,其中每个层都包含多个神经元(或称为节点)。MLP 是一种强大的模型,常用于解决分类和回归问题。

MLP 的基本组成部分如下:

  • 输入层(Input Layer): 接收原始数据的输入层,每个输入节点对应输入特征。

  • 隐藏层(Hidden Layer):
    位于输入层和输出层之间的一层或多层神经元。每个神经元通过权重与前一层的所有节点相连接,并通过激活函数进行非线性变换。隐藏层的存在使得网络能够学习输入数据的复杂特征。

  • 输出层(Output Layer): 提供最终的网络输出。对于不同的问题,输出层的激活函数可能不同。例如,对于二分类问题,可以使用
    sigmoid 激活函数;对于多分类问题,可以使用 softmax 激活函数。

模型构建

UCI wine:

我们加载 sklearn.datasets 中的 load_wine作为训练数据,划分为数据集和测试集,并进行标准化操作

接着调用 MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42) 创建模型

训练后在测试集上预测,最后评估模型
在这里插入图片描述

from sklearn.neural_network import MLPClassifier
from sklearn.datasets import load_iris
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler# 加载Iris数据集
# iris = load_iris()
# X = iris.data
# y = iris.targetwine = load_wine()
X = wine.data
y = wine.target# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 数据标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)# 构建MLP模型
mlp = MLPClassifier(hidden_layer_sizes=(100,), max_iter=1000, random_state=42)# 训练模型
mlp.fit(X_train_scaled, y_train)# 在测试集上进行预测
y_pred = mlp.predict(X_test_scaled)# 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)# 打印结果
print("Accuracy:", accuracy)
print("\nConfusion Matrix:\n", conf_matrix)
print("\nClassification Report:\n", class_report)

CIFAR-10:

我们使用 tf.keras.datasets.cifar10中自带的数据进行训练

使用 tf.keras.Sequential() 这个函数创建模型,设置四层网络

接着对代码进行批量训练,评估和保留模型后对结果进行可视化处理

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

########cifar10数据集##########
###########保存模型############
########卷积神经网络##########
#train_x:(50000, 32, 32, 3), train_y:(50000, 1), test_x:(10000, 32, 32, 3), test_y:(10000, 1)
#60000条训练数据和10000条测试数据,32x32像素的RGB图像
#第一层两个卷积层16个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#第二层两个卷积层32个3*3卷积核,一个池化层:最大池化法2*2卷积核,激活函数:ReLU
#隐含层激活函数:ReLU函数
#输出层激活函数:softmax函数(实现多分类)
#损失函数:稀疏交叉熵损失函数
#隐含层有128个神经元,输出层有10个节点
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as npimport time
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print(nowtime)#指定GPU
#import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# gpus = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(gpus[0],True)
#初始化
plt.rcParams['font.sans-serif'] = ['SimHei']#加载数据
cifar10 = tf.keras.datasets.cifar10
(train_x,train_y),(test_x,test_y) = cifar10.load_data()
print('\n train_x:%s, train_y:%s, test_x:%s, test_y:%s'%(train_x.shape,train_y.shape,test_x.shape,test_y.shape))#数据预处理
X_train,X_test = tf.cast(train_x/255.0,tf.float32),tf.cast(test_x/255.0,tf.float32)     #归一化
y_train,y_test = tf.cast(train_y,tf.int16),tf.cast(test_y,tf.int16)#建立模型
model = tf.keras.Sequential()
##特征提取阶段
#第一层
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu,data_format='channels_last',input_shape=X_train.shape[1:]))  #卷积层,16个卷积核,大小(3,3),保持原图像大小,relu激活函数,输入形状(28,28,1)
model.add(tf.keras.layers.Conv2D(16,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))   #池化层,最大值池化,卷积核(2,2)
#第二层
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.Conv2D(32,kernel_size=(3,3),padding='same',activation=tf.nn.relu))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2,2)))
##分类识别阶段
#第三层
model.add(tf.keras.layers.Flatten())    #改变输入形状
#第四层
model.add(tf.keras.layers.Dense(128,activation='relu'))     #全连接网络层,128个神经元,relu激活函数
model.add(tf.keras.layers.Dense(10,activation='softmax'))   #输出层,10个节点
print(model.summary())      #查看网络结构和参数信息#配置模型训练方法
#adam算法参数采用keras默认的公开参数,损失函数采用稀疏交叉熵损失函数,准确率采用稀疏分类准确率函数
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])#训练模型
#批量训练大小为64,迭代5次,测试集比例0.2(48000条训练集数据,12000条测试集数据)
print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练前时刻:'+str(nowtime))history = model.fit(X_train,y_train,batch_size=64,epochs=5,validation_split=0.2)print('--------------')
nowtime = time.strftime('%Y-%m-%d %H:%M:%S')
print('训练后时刻:'+str(nowtime))#评估模型
model.evaluate(X_test,y_test,verbose=2)     #每次迭代输出一条记录,来评价该模型是否有比较好的泛化能力#保存整个模型
model.save('CIFAR10_CNN_weights.h5')#结果可视化
print(history.history)
loss = history.history['loss']          #训练集损失
val_loss = history.history['val_loss']  #测试集损失
acc = history.history['sparse_categorical_accuracy']            #训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    #测试集准确率plt.figure(figsize=(10,3))plt.subplot(121)
plt.plot(loss,color='b',label='train')
plt.plot(val_loss,color='r',label='test')
plt.ylabel('loss')
plt.legend()plt.subplot(122)
plt.plot(acc,color='b',label='train')
plt.plot(val_acc,color='r',label='test')
plt.ylabel('Accuracy')
plt.legend()#暂停5秒关闭画布,否则画布一直打开的同时,会持续占用GPU内存
#根据需要自行选择
#plt.ion()       #打开交互式操作模式
#plt.show()
#plt.pause(5)
#plt.close()#使用模型
plt.figure()
for i in range(10):num = np.random.randint(1,10000)plt.subplot(2,5,i+1)plt.axis('off')plt.imshow(test_x[num],cmap='gray')demo = tf.reshape(X_test[num],(1,32,32,3))y_pred = np.argmax(model.predict(demo))plt.title('标签值:'+str(test_y[num])+'\n预测值:'+str(y_pred))
#y_pred = np.argmax(model.predict(X_test[0:5]),axis=1)
#print('X_test[0:5]: %s'%(X_test[0:5].shape))
#print('y_pred: %s'%(y_pred))#plt.ion()       #打开交互式操作模式
plt.show()
#plt.pause(5)
#plt.close()

项目地址

https://gitee.com/yishangyishang/homeword.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/222016.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux调试器gdb的用法

Linux调试器gdb的用法 1. debug/release版本之间的比较2. gdb调试器的基本指令3. 使用展示 1. debug/release版本之间的比较 在之前学习C语言的的时候出过一期vs的调试技巧。 而对于现在的Linux下的调试器gdb其实也是换汤不换药的,基本上的调试思路是不会改变的&am…

http -- 跨域问题详解(浏览器)

参考链接 参考链接 1. 跨域报错示例 Access to XMLHttpRequest at http://127.0.0.1:3000/ from origin http://localhost:3000 has been blocked by CORS policy: Response to preflight request doesnt pass access control check: No Access-Control-Allow-Origin header…

【Java 集合】LinkedBlockingDeque

在开始介绍 LinkedBlockingDeque 之前, 我们先看一下 LinkedBlockingDeque 的类图: 从其中可以看出他直接实现了 BlockingDeque 接口, 而 BlockingDeque 又实现了 BlockingQueue 的接口, 所以它本身具备了队列的特性。 而实现 BlockingDeque 使其在 BlockingQueue 的基础上多了…

Spring Boot自动装配原理以及实践

了解自动装配两个核心 Import注解的作用 Import说Spring框架经常会看到的注解,它有以下几个作用: 导入Configuration类下所有的bean方法中创建的bean。导入import指定的bean,例如Import(AService.class),就会生成AService的bean&#xff0…

获取请求体中json数据并解析到实体对象

目录 相关依赖 前端代码 后端代码 测试结果 相关依赖 <dependency><groupId>com.alibaba</groupId><artifactId>fastjson</artifactId><version>1.2.83</version> </dependency> <dependency><groupId>comm…

02 ModBus TCP

目录 一、ModBus TCP 一帧数据格式 二、0x01 读线圈状态 三、0x03读保持寄存器 四、0x05写单个线圈 五、0x06 写单个寄存器 六、0x0f写多个线圈 七、0x10&#xff1a;写多个保持寄存器 八、通信过程 九、不同modbus通信模式的应用场景 一、ModBus TCP 一帧数据格式 其…

fill-in-the-middle(FIM) 实现与简单应用

1 背景 传统训练的 GPT 模型只能根据前文内容预测后文内容&#xff0c;但有些应用比如代码生成器&#xff0c;需要我们给出上文和下文&#xff0c;使模型可以预测中间的内容&#xff0c;传统训练的 GPT 就不能完成这类任务。 传统训练的 GPT 只能根据上文预测下文 使用 FIM…

基于Pytest+Requests+Allure实现接口自动化测试

一、整体结构 框架组成&#xff1a;pytestrequestsallure 设计模式&#xff1a; 关键字驱动 项目结构&#xff1a; 工具层&#xff1a;api_keyword/ 参数层&#xff1a;params/ 用例层&#xff1a;case/ 数据驱动&#xff1a;data_driver/ 数据层&#xff1a;data/ 逻…

玩转大数据19:数据治理与元数据管理策略

随着大数据时代的到来&#xff0c;数据已经成为企业的重要资产。然而&#xff0c;如何有效地管理和利用这些数据&#xff0c;成为了一个亟待解决的问题。数据治理和元数据管理是解决这个问题的关键。 1.数据治理的概念和重要性 数据治理是指对数据进行全面、系统、规范的管理…

MLOps在极狐GitLab 的现状和前瞻

什么是 MLOps 首先我们可以这么定义机器学习&#xff08;Machine Learning&#xff09;&#xff1a;通过一组工具和算法&#xff0c;从给定数据集中提取信息以进行具有一定程度不确定性的预测&#xff0c;借助于这些预测增强用户体验或推动内部决策。 同一般的软件研发流程比…

行为型设计模式(一)模版方法模式 迭代器模式

模板方法模式 Template 1、什么是模版方法模式 模版方法模式定义了一个算法的骨架&#xff0c;它将其中一些步骤的实现推迟到子类里面&#xff0c;使得子类可以在不改变算法结构的情况下重新定义算法中的某些步骤。 2、为什么使用模版方法模式 封装不变部分&#xff1a;模版…

vscode配置node.js调试环境

node.js基于VSCode的开发环境的搭建非常简单。 说明&#xff1a;本文的前置条件是已安装好node.js(具体安装不再赘述&#xff0c;如有需要可评论区留言)。 阅读本文可掌握&#xff1a; 方便地进行js单步调试&#xff1b;方便地查看内置的对象或属性&#xff1b; 安装插件 C…

RouterSrv-DHCP

2023年全国网络系统管理赛项真题 模块B-Windows解析 题目 安装和配置DHCP relay服务,为办公区域网络提供地址上网。DHCP服务器位于AppSrv服务器上。拆分DHCP服务器上的作用域,拆分的百分比为7:3。InsideCli优先从RouterSrv获取地址。配置步骤 安装和配置DHCP relay服务,为办…

AIGC:阿里开源大模型通义千问部署与实战

1 引言 通义千问-7B&#xff08;Qwen-7B&#xff09;是阿里云研发的通义千问大模型系列的70亿参数规模的模型。Qwen-7B是基于Transformer的大语言模型, 在超大规模的预训练数据上进行训练得到。预训练数据类型多样&#xff0c;覆盖广泛&#xff0c;包括大量网络文本、专业书籍…

云原生消息流系统 Apache Pulsar 在腾讯云的大规模生产实践

导语 由 InfoQ 主办的 Qcon 全球软件开发者大会北京站上周已精彩落幕&#xff0c;腾讯云中间件团队的冉小龙参与了《云原生机构设计与音视频技术应用》专题&#xff0c;带来了以《云原生消息流系统 Apache Pulsar 在腾讯云的大规模生产实践》为主题的精彩演讲&#xff0c;在本…

Linux shell编程学习笔记37:readarray命令和mapfile命令

目录 0 前言1 readarray命令的格式和功能 1.1 命令格式1.2 命令功能1.3 注意事项2 命令应用实例 2.1 从标准输入读取数据时不指定数组名&#xff0c;则数据会保存到MAPFILE数组中2.2 从标准输入读取数据并存储到指定的数组2.3 使用 -O 选项指定起始下标2.4 用-n指定有效行数…

21.Servlet 技术

JavaWeb应用的概念 在Sun的Java Servlet规范中&#xff0c;对Java Web应用作了这样定义&#xff1a;“Java Web应用由一组Servlet、HTML页、类、以及其它可以被绑定的资源构成。它可以在各种供应商提供的实现Servlet规范的 Servlet容器 中运行。” Java Web应用中可以包含如下…

人工智能的发展之路:时间节点、问题与解决办法的全景解析

导言 人工智能的发展历程充满了里程碑式的事件&#xff0c;从早期的概念到今天的广泛应用&#xff0c;每个时间节点都伴随着独特的挑战和创新。本文将详细描述每个关键时间节点的事件&#xff0c;探讨存在的问题、解决办法&#xff0c;以及不同阶段之间的联系。 1. 195…

重温经典struts1之自定义转换器及注册的两种方式(Servlet,PlugIn)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 前言 Struts的ActionServlet接收用户在浏览器发送的请求&#xff0c;并将用户输入的数据&#xff0c;按照FormBean中定义的数据类型&#xff0c;赋值给FormBean中每个变量&a…

Databend 源码阅读: Meta-service 数据结构

作者&#xff1a;张炎泼&#xff08;XP&#xff09; Databend Labs 成员&#xff0c;Databend 分布式研发负责人 drmingdrmer (张炎泼) GitHub 引言 Databend 是一款开源的云原生数据库&#xff0c;采用 Rust 语言开发&#xff0c;专为云原生数据仓库的需求而设计。 面向云架…