文章目录
一、前期工作
- 导入库
- 数据集加载
二、构建CNN模型
三、训练过程曲线函数
四、训练模型函数
五、训练模型与结果
六、验证
大家好,今天给大家带来一个卷积神经网络(CNN)数学图形识别项目(简单入门版),这个是人工智能解题的基础,机器首先通过题目识别出题目中的文字和图形,读懂题目的含义,这个是个相对复杂的过程。就在今年的1月4日,麻省理工学院等四所高校的联合研究团队,发布了一项最新研究成果:他们开发了一个神经网络,可以解答出微积分、线性代数等大学数学题。不管是要求计算数值,还是写方程式,或者画出函数图形,都能轻易解答,正确率达到了100%。要知道,在短短几个月前,人工智能解答类似的题,最高正确率不到10%。
解数学题不是直接算公式,而是需要你去理解文字,此外,有的题还需要联系上下文,并且包括了一些隐含条件。比如,一道概率论的题目,问:“在扑克游戏中,拿到两副对子的概率是多少?”这道题在人看来很清楚,但是对于计算机来说,其实是有很多隐含条件的。比如,一副扑克有54张牌、4种花色+两张鬼牌等等。人工智能不知道这些隐含条件,就没法算题。
一、前期工作
1.导入库
import tensorflow as tf
# from keras import keras.layers
import matplotlib.pyplot as plt
from time import *
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,MaxPooling2D,Flatten,Conv2D,experimental
2. 数据集加载
data_dir = "./data/图形识别训练"batch_size = 12
img_height = 224
img_width = 224train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=12,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=12,image_size=(img_height, img_width),batch_size=batch_size)
二、构建DenseNet模型
model = tf.keras.applications.DenseNet121(weights='imagenet')
model.summary()# 设置初始学习率
initial_learning_rate = 1e-3lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate,decay_steps=5, # 敲黑板!!!这里是指 steps,不是指epochsdecay_rate=0.96, # lr经过一次衰减就会变成 decay_rate*lrstaircase=True)# 将指数衰减学习率送入优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
model.compile(optimizer=optimizer,loss ='sparse_categorical_crossentropy',metrics =['accuracy'])
三、训练过程的曲线函数
def show_loss_acc(history):# 从history中提取模型训练集和验证集准确率信息和误差信息acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']# 按照上下结构将图画输出plt.figure(figsize=(8, 8))plt.subplot(2, 1, 1)plt.plot(acc, label='Training Accuracy')plt.plot(val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.ylabel('Accuracy')plt.ylim([min(plt.ylim()), 1])plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)plt.plot(loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.ylabel('Cross Entropy')plt.title('Training and Validation Loss')plt.xlabel('epoch')plt.savefig('results/results_cnn.png', dpi=100)plt.show()
四、训练模型函数
def train(epochs):# 开始训练,记录开始时间begin_time = time()AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)# print(class_names)# 加载模型# model = model_load(class_num=len(class_names))# 指明训练的轮数epoch,开始训练history = model.fit(train_ds,validation_data=val_ds,epochs=epochs)# todo 保存模型, 修改为你要保存的模型的名称model.save("models/cnn_fv.h5")# 记录结束时间end_time = time()run_time = end_time - begin_timeprint('该循环程序运行时间:', run_time, "s") # 该循环程序运行时间: 1.4201874732# 绘制模型训练过程图show_loss_acc(history)
图形识别训练数集里面共有4类:circular(圆形),parabola(抛物线),square(正方形),triangle(三角形)
triangle(三角形)
square(正方形)
五、训练模型与结果
train(epochs=6) #训练6次41/41 [==============================] - 218s 5s/step - loss: 1.5289 - accuracy: 0.7224 - val_loss: 21.6089 - val_accuracy: 0.0410
Epoch 2/6
41/41 [==============================] - 206s 5s/step - loss: 0.2921 - accuracy: 0.9204 - val_loss: 4.2023 - val_accuracy: 0.6393
Epoch 3/6
41/41 [==============================] - 210s 5s/step - loss: 0.0962 - accuracy: 0.9673 - val_loss: 0.4482 - val_accuracy: 0.9180
Epoch 4/6
41/41 [==============================] - 209s 5s/step - loss: 0.0406 - accuracy: 0.9898 - val_loss: 0.0980 - val_accuracy: 0.9672
Epoch 5/6
41/41 [==============================] - 205s 5s/step - loss: 0.0149 - accuracy: 1.0000 - val_loss: 0.0269 - val_accuracy: 0.9918
Epoch 6/6
41/41 [==============================] - 220s 5s/step - loss: 0.0061 - accuracy: 1.0000 - val_loss: 0.0220 - val_accuracy: 0.9918
该循环程序运行时间: 1269.1464262008667 s
测试和验证结果准确率图:
迭代6次后,训练集准确率高达100%,验证集准确率高达99.18%
六、验证
输入图片:
预测结果:parabola(抛物线)
后续将会通过OCR识别题目中的文字,以及Latex数学公式的识别,识别题目含义,最终结果调用抛物线的解题程序,进行简单的解题。
敬请关注,数据集的获取私信我!
往期作品:
深度学习实战项目
1.深度学习实战1-(keras框架)企业数据分析与预测
2.深度学习实战2-(keras框架)企业信用评级与预测
3.深度学习实战3-文本卷积神经网络(TextCNN)新闻文本分类
4.深度学习实战4-卷积神经网络(DenseNet)数学图形识别+题目模式识别
5.深度学习实战5-卷积神经网络(CNN)中文OCR识别项目
6.深度学习实战6-卷积神经网络(Pytorch)+聚类分析实现空气质量与天气预测
7.深度学习实战7-电商产品评论的情感分析
8.深度学习实战8-生活照片转化漫画照片应用
9.深度学习实战9-文本生成图像-本地电脑实现text2img
10.深度学习实战10-数学公式识别-将图片转换为Latex(img2Latex)
11.深度学习实战11(进阶版)-BERT模型的微调应用-文本分类案例
12.深度学习实战12(进阶版)-利用Dewarp实现文本扭曲矫正
13.深度学习实战13(进阶版)-文本纠错功能,经常写错别字的小伙伴的福星
14.深度学习实战14(进阶版)-手写文字OCR识别,手写笔记也可以识别了
15.深度学习实战15(进阶版)-让机器进行阅读理解+你可以变成出题者提问
16.深度学习实战16(进阶版)-虚拟截图识别文字-可以做纸质合同和表格识别
17.深度学习实战17(进阶版)-智能辅助编辑平台系统的搭建与开发案例
18.深度学习实战18(进阶版)-NLP的15项任务大融合系统,可实现市面上你能想到的NLP任务
19.深度学习实战19(进阶版)-ChatGPT的本地实现部署测试,自己的平台就可以实现ChatGPT
…(待更新)