python实现——分类类型数据挖掘任务(图形识别分类任务)

  1. 分类类型数据挖掘任务

基于卷积神经网络(CNN)的岩石图像分类。有一岩石图片数据集,共300张岩石图片,图片尺寸224x224。岩石种类有砾岩(Conglomerate)、安山岩(Andesite)、花岗岩(Granite)、石灰岩(Limestone)、石英岩(Quartzite)和5种,每种岩石图片各50张,共250张。请选择合适模型对该数据集进行建模,训练优化模型并给出模型评估指标,再利用GUI框架开发岩石图片分类界面。

1.1总体流程

1.2数据增强

定义:数据增强是利用现有数据生成新的数据来增加数据量的过程,能够有效地扩充训练数据集的大小,提高模型的泛化能力,同时也能够有效地防止过拟合现象的发生。

本项目采用的数据增强方法:

(1)水平翻转

(2)缩放

(3)旋转

(4)添加高斯噪音

(5)调整对比度和亮度

通过数据增强,数据集从之前的250张扩充至1500张,数据量为之前的6倍。

参考代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)

 结果:

1.3数据预处理

将1500张图片依次读入并转化为可训练的数据(特征变量(X)和标签(Y))

代码:

import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)

1.4模型构建

1.4.1模型结构定义

模型参数:

参考代码:

from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  tf.keras.layers.Flatten(),  tf.keras.layers.Dense(84, activation='relu'),  tf.keras.layers.Dropout(0.3),  tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  # 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  metrics=['sparse_categorical_accuracy'])  # 监控准确率  # 打印模型概述  
model.summary()  # 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  

 

1.4.2模型译

编译参数参考:

# 优化器optimizer='adam'# 损失函数loss='sparse_categorical_crossentropy'# 评估指标metrics=['sparse_categorical_accuracy']

1.5模型训练

1.5.1划分训练集和测试集

按照训练集:测试集=8:2的比例对数据集进行划分,建议使用sklearn库中的train_test_split函数。

1.5.2训练

使用fit函数对训练集进行拟合训练,并将训练过程中产生的历史数据history保存至变量中。

训练参数参考:

# 迭代次数epochs=20# 验证集比例validation_split=0.2

1.5.3训练过程可视化

对history中保存下来的训练过程中的loss和sparse_categorical_accuracy的变化情况进行绘图。

参考代码:

# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  # 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show

 

 

1.6.3保存模型

使用save函数对训练好的模型进行保存,方便后续使用。

参考代码:

model.save('roch_classification_cnn.h5')

1.7图形用户界面(GUI)开发

1.7.1配置开发工具

在PyCharm中配置QtDesigner和PyUIC工具。

注意:需提前在python环境中安装好PyQt5和PyQt5-tools库。

  1. 配置QtDesigner

Program:(对应designer.exe的路径)

Working directory: $FileDir$

  1. 配置PyUCI

Program:(对应pyuic5.exe的路径)

Arguments: $FileName$ -o $FileNameWithoutExtension$.py

Working directory: $FileDir$

配置完成后的界面:

1.7.2设计图形用户界面

在PyCharm中“Tools”—“External Tools”中打开QtDesigner

在QtDesigner主界面中选择创建Main Window,然后根据需求选择相应的控件进行设计。

设计界面参考:

设计好之后保存为.ui文件。

1.7.3 ui文件转换为代码

在PyCharm中右键点击.ui文件并使用PyUCI工具进行转换。

1.7.4代码与模型结合

将转化后的代码与之前训练的模型相结合。

参考代码:

# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path)  # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())

1.7.5测试

执行程序测试“导入图片”和“鉴定分类”功能。

1.8打包可执行文件(exe)

在命令窗口中使用如下指令对上一步的程序进行打包。

Pyinstaller -F -w xxxxx.py

运行生成的.exe文件并测试功能。

打完包之后可能出现错误

报错信息:

=============================================================

A RecursionError (maximum recursion depth exceeded) occurred.

For working around please follow these instructions

=============================================================

1. In your program's .spec file add this line near the top::

     import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5)

2. Build your program by running PyInstaller with the .spec file as

   argument::

     pyinstaller myprog.spec

3. If this fails, you most probably hit an endless recursion in

   PyInstaller. Please try to track this down has far as possible,

   create a minimal example so we can reproduce and open an issue at

   https://github.com/pyinstaller/pyinstaller/issues following the

   instructions in the issue template. Many thanks.

Explanation: Python's stack-limit is a safety-belt against endless recursion,

eating up memory. PyInstaller imports modules recursively. If the structure

how modules are imported within your program is awkward, this leads to the

nesting being too deep and hitting Python's stack-limit.

With the default recursion limit (1000), the recursion error occurs at about

115 nested imported, with limit 2000 at about 240, with limit 5000 at about

660.

————————————————

你打包目录下会生成如下文件

打开你的main.spec文件

在顶端添加代码:

import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

然后在运行命令(对应的文件名)

pyinstaller 你的文件名.spec

然后就完成了

打完包之的运行闪退问题:

先安装一个新的第三方库ordereddict

安装命令:

pip install ordereddict

注意自己python代码的文件引入路径(确保对应的路径下有对应的文件,我这里设置的是根目录下)

重新打包

完成之后

打开对应的文件夹双击就可以了

完整代码:

import cv2
import os
import glob
# 数据增强函数
def augment_data(img, save_path):rows, cols, _ = img.shape# 水平翻转图像img_flip = cv2.flip(img, 1)img_name = os.path.splitext(save_path)[0] + "_flip.jpg"cv2.imwrite(img_name, img_flip)print("Saved augmented image:", img_name)# 随机缩放图像scale = np.random.uniform(0.9, 1.1)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 0, scale)img_transformed = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_transform.jpg"cv2.imwrite(img_name, img_transformed)print("Saved augmented image:", img_name)# 随机旋转图像angle = np.random.randint(-10, 10)M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)img_rotated = cv2.warpAffine(img, M, (cols, rows))img_name = os.path.splitext(save_path)[0] + "_rotated.jpg"cv2.imwrite(img_name, img_rotated)print("Saved augmented image:", img_name)# 添加高斯噪音mean = 0std = np.random.uniform(5, 15)noise = np.zeros(img.shape, np.float32)cv2.randn(noise, mean, std)noise = np.uint8(noise)img_noisy = cv2.add(img, noise)img_name = os.path.splitext(save_path)[0] + "_noisy.jpg"cv2.imwrite(img_name, img_noisy)print("Saved augmented image:", img_name)# 随机调整对比度和亮度alpha = np.random.uniform(0.8, 1.2)beta = np.random.randint(-10, 10)img_contrast = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)img_name = os.path.splitext(save_path)[0] + "_contrast.jpg"cv2.imwrite(img_name, img_contrast)print("Saved augmented image:", img_name)return img
# 读取 data 文件夹中的所有图片,并进行数据增强
data_dir = r"images"
save_dir = r"images2"
if not os.path.exists(save_dir):os.makedirs(save_dir)
# 使用 glob 库来遍历 data 文件夹中所有图像
for img_path in glob.glob(os.path.join(data_dir, "*.jpg")):img = cv2.imread(img_path)if img is None:print("Error: Unable to read image at", img_path)continue# 获取保存增强后的图片文件名img_name = os.path.basename(img_path)save_path = os.path.join(save_dir, img_name)# 数据增强augmented_img = augment_data(img, save_path)if augmented_img is not None:# 保存原始图片cv2.imwrite(save_path, img)print("Saved original image:", save_path)
#%%
import os
import cv2
import numpy as np
from PIL import Image
# 设置图片文件夹路径
image_folder = r"images2"
# 获取所有类别的文件夹名(假设每个文件夹是一个类别)
categories = os.listdir(image_folder)# 初始化特征变量 X 和标签 Y 的列表
X_list = np.zeros((len(categories), 224, 224, 3))
Y_list = np.zeros((len(categories)))i=0
for name in categories:img = Image.open(image_folder + '\\' +name)img_rgb = img.split()X_list[i,:,:,0] = np.array(img_rgb[0])/255X_list[i,:,:,1] = np.array(img_rgb[1])/255X_list[i,:,:,2] = np.array(img_rgb[2])/255Y_list[i] = name.split('_')[0]i+=1
# 将特征变量 X 和标签 Y 的列表转化为 NumPy 数组
X = np.array(X_list)
Y = np.array(Y_list)# 打印特征变量 X 和标签 Y 的形状
print('特征变量 X 的形状:', X)
print('标签 Y 的形状:', Y)
#%%
from sklearn.model_selection import train_test_split
import seaborn as sns  
import matplotlib.pyplot as plt  
import tensorflow as tf
from sklearn.metrics import confusion_matrix  
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# 5个类别  
num_classes = 5  
# 输入图像的大小是224x224,有3个颜色通道(对于彩色图像)  
input_shape = (224, 224, 3)  
# 假设X和Y是您的原始数据  
# X: 图像数据,形状为(num_samples, 224, 224, 3)  
# Y: 标签数据,形状为(num_samples,) 并且是整数形式的标签(从0到4)  
# 将数据划分为训练集和测试集(只执行一次)  
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)  
# 构建模型  
model = tf.keras.models.Sequential([  tf.keras.layers.Conv2D(6, (5, 5), strides=(1,1), activation='relu', input_shape=input_shape),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(16, (5,5), activation='relu'),  tf.keras.layers.MaxPooling2D((2,2), strides=2),  tf.keras.layers.Conv2D(120, (5,5), activation='relu'),  tf.keras.layers.Flatten(),  tf.keras.layers.Dense(84, activation='relu'),  tf.keras.layers.Dropout(0.3),  tf.keras.layers.Dense(num_classes, activation='softmax')  # 确保输出层的神经元数量与类别数量匹配  
])  # 编译模型  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),# 使用sparse categorical crossentropy损失函数   optimizer=tf.keras.optimizers.Adam(),  # 使用Adam优化器  metrics=['sparse_categorical_accuracy'])  # 监控准确率  # 打印模型概述  
model.summary()  # 使用model.fit()函数训练模型  
history = model.fit(x_train, y_train, epochs=10, validation_split=0.2)  #%%
y_pred = model.predict(x_test) 
print(y_pred)
#%%#%%
# 获取训练和验证的准确率和损失  
acc = history.history['sparse_categorical_accuracy']  
val_acc = history.history['sparse_categorical_accuracy']  
loss = history.history['loss']  
val_loss = history.history['val_loss']  # 使用model.evaluate()函数评估模型在测试集上的性能  
test_loss, test_accuracy = model.evaluate(x_test, y_test)  
print(f'Test accuracy: {test_accuracy}')  # 使用model.predict()函数对新的图像进行预测。
plt.figure(figsize=(15,10))
plt.plot(history.epoch, history.history['loss'],label='loss')
plt.plot(history.epoch, history.history['val_loss'],label='var_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend(loc='upper right')plt.figure(figsize=(15,10))
plt.plot(history.epoch,history.history['sparse_categorical_accuracy'],label='sparse_categorical_accuracy')
plt.plot(history.epoch,history.history['val_sparse_categorical_accuracy'],label='val_sparse_categorical_accuracy')
plt.xlabel('Epoch')
plt.ylabel('sparse_categorical_accuracy')
plt.legend(loc='upper right')
plt.show()plt.rcParams['font.sans-serif'] = ['SimHei'] 
y_pred = np.argmax(model.predict(x_test),axis=1)
cm = confusion_matrix(y_test, y_pred,labels=[0,1,2,3,4])
sns.heatmap(cm,annot=True,cmap="Blues",cbar=False,linewidths=2,linecolor='white',square=True,xticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'],yticklabels=['砾岩','安山岩','花岗岩','石灰岩','石英岩'])
plt.show
#%%
model.save('roch_classification_cnn.h5')

# -*- coding: utf-8 -*-
import osfrom PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
import tensorflow as tf
from PIL import Image
import numpy as np
import sys
class Ui_MainWindow(object):def setupUi(self, MainWindow):MainWindow.setObjectName("MainWindow")MainWindow.resize(800, 600)self.centralwidget = QtWidgets.QWidget(MainWindow)self.centralwidget.setObjectName("centralwidget")self.label = QtWidgets.QLabel(self.centralwidget)self.label.setGeometry(QtCore.QRect(220, 20, 291, 61))self.label.setScaledContents(False)self.label.setObjectName("label")self.pushButton = QtWidgets.QPushButton(self.centralwidget)self.pushButton.setGeometry(QtCore.QRect(160, 430, 93, 28))self.pushButton.setObjectName("pushButton")self.pushButton_2 = QtWidgets.QPushButton(self.centralwidget)self.pushButton_2.setGeometry(QtCore.QRect(440, 430, 93, 28))self.pushButton_2.setObjectName("pushButton_2")self.label_2 = QtWidgets.QLabel(self.centralwidget)self.label_2.setGeometry(QtCore.QRect(150, 90, 381, 321))self.label_2.setText("")self.label_2.setObjectName("label_2")self.label_3 = QtWidgets.QLabel(self.centralwidget)self.label_3.setGeometry(QtCore.QRect(550, 130, 141, 51))self.label_3.setText("")self.label_3.setObjectName("label_3")self.label_4 = QtWidgets.QLabel(self.centralwidget)self.label_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.label_4.setObjectName("label_4")self.textBrowser = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser.setGeometry(QtCore.QRect(150, 90, 381, 321))self.textBrowser.setObjectName("textBrowser")self.textBrowser_2 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_2.setGeometry(QtCore.QRect(550, 130, 141, 51))self.textBrowser_2.setObjectName("textBrowser_2")self.textBrowser_3 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_3.setGeometry(QtCore.QRect(220, 20, 291, 61))self.textBrowser_3.setObjectName("textBrowser_3")self.textBrowser_4 = QtWidgets.QTextBrowser(self.centralwidget)self.textBrowser_4.setGeometry(QtCore.QRect(550, 90, 141, 31))self.textBrowser_4.setObjectName("textBrowser_4")self.textBrowser_2.raise_()self.label.raise_()self.textBrowser.raise_()self.textBrowser_3.raise_()self.pushButton.raise_()self.pushButton_2.raise_()self.label_2.raise_()self.label_4.raise_()self.textBrowser_4.raise_()self.label_3.raise_()MainWindow.setCentralWidget(self.centralwidget)self.menubar = QtWidgets.QMenuBar(MainWindow)self.menubar.setGeometry(QtCore.QRect(0, 0, 800, 26))self.menubar.setObjectName("menubar")MainWindow.setMenuBar(self.menubar)self.statusbar = QtWidgets.QStatusBar(MainWindow)self.statusbar.setObjectName("statusbar")MainWindow.setStatusBar(self.statusbar)self.toolBar = QtWidgets.QToolBar(MainWindow)self.toolBar.setObjectName("toolBar")MainWindow.addToolBar(QtCore.Qt.TopToolBarArea, self.toolBar)self.retranslateUi(MainWindow)QtCore.QMetaObject.connectSlotsByName(MainWindow)# 模型相关变量初始化self.model = tf.keras.models.load_model(r'C:\Users\zjl15\PycharmProjects\pythonProject1\roch_classification_cnn.h5')self.path = ''self.rock_types = ['砾岩','安山岩','花岗岩','石灰岩','石英岩']# 将“导入图片”按钮与openImage函数绑定self.pushButton.clicked.connect(self.openImage)# 将“岩石分类”按钮与classify函数绑定self.pushButton_2.clicked.connect(self.classify)def retranslateUi(self, MainWindow):_translate = QtCore.QCoreApplication.translateMainWindow.setWindowTitle(_translate("MainWindow", "MainWindow"))self.label.setText(_translate("MainWindow", "岩石图像分类"))self.pushButton.setText(_translate("MainWindow", "导入图像"))self.pushButton_2.setText(_translate("MainWindow", "岩石分类"))self.label_4.setText(_translate("MainWindow", "分类结果"))self.textBrowser_3.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:24pt;\">岩石图像识别</span></p></body></html>"))self.textBrowser_4.setHtml(_translate("MainWindow","<!DOCTYPE HTML PUBLIC \"-//W3C//DTD HTML 4.0//EN\" \"http://www.w3.org/TR/REC-html40/strict.dtd\">\n""<html><head><meta name=\"qrichtext\" content=\"1\" /><style type=\"text/css\">\n""p, li { white-space: pre-wrap; }\n""</style></head><body style=\" font-family:\'SimSun\'; font-size:9pt; font-weight:400; font-style:normal;\">\n""<p align=\"center\" style=\" margin-top:0px; margin-bottom:0px; margin-left:0px; margin-right:0px; -qt-block-indent:0; text-indent:0px;\"><span style=\" font-size:11pt;\">分类结果</span></p></body></html>"))self.toolBar.setWindowTitle(_translate("MainWindow", "toolBar"))# 导入图片函数def resource_path(relative):if hasattr(sys, "_MEIPASS"):absolute_path = os.path.join(sys._MEIPASS, relative)else:absolute_path = os.path.join(relative)return absolute_path# 在原来引用该文件的地方加上这个函数 (resource_path("文件名"))def openImage(self):imgPath, imgType = QFileDialog.getOpenFileName(None, "导入图片", "", "*.jpg;;*.png;;All Files(*)")jpg = QtGui.QPixmap(imgPath).scaled(self.label_2.width(), self.label_2.height())self.label_2.setPixmap(jpg)self.path=imgPathself.label_3.setText('')def classify(self):img = Image.open(self.path)  # 读取图像img_rgb = img.split()x = np.zeros((1, 224, 224, 3))x[0,:, :, 0] = np.array(img_rgb[0]) / 255x[0,:, :, 1] = np.array(img_rgb[1]) / 255x[0,:, :, 2] = np.array(img_rgb[2]) / 255y = self.model.predict(x)result = self.rock_types[np.argmax(y)]self.label_3.setText(result)
if __name__=='__main__':QtCore.QCoreApplication.setAttribute(QtCore.Qt.AA_EnableHighDpiScaling)app=QtWidgets.QApplication(sys.argv)MainWindow=QtWidgets.QMainWindow()ui_test=Ui_MainWindow()ui_test.setupUi(MainWindow)MainWindow.show()sys.exit(app.exec_())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/337488.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

体验Photoshop:无需下载,直接在浏览器编辑图片

搜索Photoshop时&#xff0c;映入眼帘的是PS软件下载&#xff0c;自学PS软件需要多长时间&#xff0c;学PS软件有必要报班吗...PS软件的设计功能很多&#xff0c;除了常见的图像处理功能外&#xff0c;还涉及图形、文本、视频、出版等。不管你是平面设计师&#xff0c;UI/UX设计…

visual studio code 全局搜索

VScode写代码的时候&#xff0c;会经常性的需要进行查找代码&#xff0c;那么怎么在Visual Studio Code中进行查找呢&#xff0c;下面就来大家vscode全局搜索的方法。 想要在vscode全局搜索进行全局搜索&#xff0c;使用快捷键CTRLSHIFTF即可进行搜索&#xff0c;也可以在左边…

免费分享一套微信小程序图书借阅(图书管理)系统(SpringBoot后端)

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的微信小程序图书借阅(图书管理)系统(&#xff0c;分享下哈。 项目介绍 该项目是一套图书馆信息管理系统&#xff0c;包括用户小程序以及后台管理系统&#xff0c;基于SpringBootMyBatis实现。前台商城系统包…

Linux主机安全可视化运维(免费方案)

本文介绍如何使用免费的主机安全软件,在自有机房或企业网络实现对Linux系统进行可视化“主机安全”管理。 一、适用对象 本文适用于个人或企业内的Linux服务器运维场景,实现免费、高效、可视化的主机安全管理。提前发现主机存在的安全风险,全方位实时监控主机运行时入侵事…

Windows 11 Beta 22635.3646 预览版发布:中国大陆地区新增“微软电脑管家”应用

微软今天面向 Beta 频道的 Windows Insider 项目成员&#xff0c;发布了适用于 Windows 11 的 KB5037858 更新&#xff0c;用户安装后版本号升至 Build 22635.3646&#xff0c;该版本主要为中国大陆设备新增“微软电脑管家”应用。 IT之家 5 月 24 日消息&#xff0c;微软今天…

LabVIEW中PID控制器系统的噪声与扰动抑制策略

在LabVIEW中处理PID控制器系统中的噪声和外部扰动&#xff0c;需要从信号处理、控制算法优化、硬件滤波和系统设计四个角度入手。采用滤波技术、调节PID参数、增加前馈控制和实施硬件滤波器等方法&#xff0c;可以有效减少噪声和扰动对系统性能的影响&#xff0c;提高控制系统的…

PBR系列-物理材质(上)

作者&#xff1a;游梦 对PBR系列文章感兴趣还可以看前文&#xff1a; PBR系列 - 物理光源 PBR系列-光之简史 前面两篇文章分别介绍了物理光源与光学研究简史&#xff0c;在对光有了简单认识之后&#xff0c;再认识物理材质会发现其实本质上还是对光的研究&#xff0c;再深入…

阿里云 通过EIP实现VPC下的SNAT以及DNAT

192.168.0.85 有公网地址192.1680.95无公网地址 在192.168.0.85&#xff08;有公网地址服务器上操作&#xff09; #开启端口转发 echo "net.ipv4.ip_forward 1" >> /etc/sysctl.conf sysctl -p#仅允许192.168.0.95 iptables -t nat -I POSTROUTING -s 192.16…

SqliSniper:针对HTTP Header的基于时间SQL盲注模糊测试工具

关于SqliSniper SqliSniper是一款基于Python开发的强大工具&#xff0c;该工具旨在检测HTTP请求Header中潜在的基于时间的SQL盲注问题。 该工具支持通过多线程形式快速扫描和识别目标应用程序中的潜在漏洞&#xff0c;可以大幅增强安全评估过程&#xff0c;同时确保了速度和效…

Pycharm使用时的红色波浪线报错——形如‘break‘ outside loop

背景&#xff1a; 我在一个方法中&#xff0c;写了一个if判断&#xff0c;写了一个break&#xff0c;期望终止这个函数&#xff0c;编辑器出现报错 形如下图 视频版问题教程&#xff1a; Pycharm下出现波浪线报错&#xff0c;形如break outside loop 过程&#xff1a; 很奇…

ROS2在RVIZ2中加载机器人urdf模型

参考ROS2-rviz2显示模型 我这边用的solid works生成的urdf以及meshes&#xff0c;比参考的方法多了meshes 问题一&#xff1a;Error retrieving file [package://rm_dcr_description/meshes/leftarm_link7.STL]: Package [rm_dcr_description] does not exist 这个是urdf模型中…

python-pytorch编写transformer模型实现问答0.5.00--训练和预测

python-pytorch编写transformer模型实现问答0.5.00--训练和预测 背景代码训练预测效果 背景 代码写不了这么长&#xff0c;接上一篇 https://blog.csdn.net/m0_60688978/article/details/139360270 代码 # 定义解码器类 n_layers 6 # 设置 Decoder 的层数 class Decoder(…

向量数据库引领 AI 创新——Zilliz 亮相 2024 亚马逊云科技中国峰会

2024年5月29日&#xff0c;亚马逊云科技中国峰会在上海召开&#xff0c;此次峰会聚集了来自全球各地的科技领袖、行业专家和创新企业&#xff0c;探讨云计算、大数据、人工智能等前沿技术的发展趋势和应用场景。作为领先的向量数据库技术公司&#xff0c;Zilliz 在本次峰会上展…

SpringBoot+layui实现Excel导入操作

excel导入步骤 第三方插件引入插件 效果图 &#xff08;方法1&#xff09;代码实现&#xff08;方法1&#xff09;Html代码&#xff08; 公共&#xff09;下载导入模板 js实现 &#xff08;方法1&#xff09;上传文件实现 效果图&#xff08;方法2&#xff09;代码实现&#xf…

mimkatz获取windows10明文密码

目录 mimkatz获取windows10明文密码原理 lsass.exe进程的作用 mimikatz的工作机制 Windows 10的特殊情况 实验 实验环境 实验工具 实验步骤 首先根据版本选择相应的mimikatz 使用管理员身份运行cmd 修改注册表 ​编辑 重启 重启电脑后打开mimikatz 在cmd切换到mi…

Matlab|基于粒子群算法优化Kmeans聚类的居民用电行为分析

目录 主要内容 部分代码 结果一览 下载链接 主要内容 在我们研究电力系统优化调度模型的过程中&#xff0c;由于每天负荷和分布式电源出力随机性和不确定性&#xff0c;可能会优化出很多的结果&#xff0c;但是经济调度模型试图做到通用策略&#xff0c;同样的策…

HarmonyOS鸿蒙学习笔记(25)相对布局 RelativeContainer详细说明

RelativeContainer 简介 前言核心概念官方实例官方实例改造蓝色方块改造center 属性说明参考资料 前言 RelativeContainer是鸿蒙的相对布局组件&#xff0c;它的布局很灵活&#xff0c;可以很方便的控制各个子UI 组件的相对位置&#xff0c;其布局理念有点类似于android的约束…

如何看待时间序列与机器学习?

GPT-4o 时间序列与机器学习的关联在于&#xff0c;时间序列数据是一种重要的结构化数据形式&#xff0c;而机器学习则是一种强大的工具&#xff0c;用于从数据中提取有用的模式和信息。在很多实际应用中&#xff0c;时间序列与机器学习可以结合起来&#xff0c;发挥重要作用。…

基于 Apache Doris 的实时/离线一体化架构,赋能中国联通 5G 全连接工厂解决方案

作者&#xff1a;田向阳&#xff0c;联通西部创新研究院 大数据专家 共创&#xff1a;SelectDB 技术团队 导读&#xff1a; 数据是 5G 全连接工厂的核心要素&#xff0c;为支持全方位的数据收集、存储、分析等工作的高效进行&#xff0c;联通 5G 全连接工厂从典型的 Lambda 架…

利用ArcGIS Python批量拼接遥感影像(arcpy batch processing)

本篇文章将说明如何利用ArcGIS 10.1自带的Python IDLE进行遥感影像的批量拼接与裁剪。 1.运行环境&#xff1a;ArcGIS10.1 (安装传送门)、Python IDLE 2.数据来源&#xff1a;地理空间数据云 GDEMV2 30M分辨率数字高程数据 3.解决问题&#xff1a;制作山西省的DEM影像 如下…