二、模型训练与优化(1):构建并训练模型

目录

1. 安装 Anaconda(推荐)

步骤:

2. 创建并激活虚拟环境

步骤:

3. 安装必要的库

步骤:

4. 编写训练脚本

步骤:

5. 运行训练脚本

步骤:

总结:


在完成了准备工作的基础上,接下来进行模型训练与优化。以 MNIST 手写数字识别 为例,使用 TensorFlowKeras 构建、训练并优化一个简单的神经网络模型,并为后续部署到 STM32 做准备。

目前已经完成了 Python、TensorFlow/Keras 等环境的安装与配置,具体准备工作见博客:

一、准备工作(1):在计算机中安装Python-CSDN博客

一、准备工作(2):部署TensorFlow和Keras-CSDN博客


1. 安装 Anaconda(推荐)

Anaconda 是一个非常流行的 Python 发行版,包含了大量的数据科学和机器学习库,并且提供了虚拟环境管理工具 conda。使用 Anaconda 可以简化库的安装和环境配置,特别适合初学者。

步骤:

(1)下载 Anaconda 安装程序:

访问 Anaconda 官方下载页面。填写邮箱后,相应下载链接就会发送到邮箱,点击下载

在页面中,找到适用于 Windows 的下载选项,点击下载 64-bit Graphical Installer(图形安装程序)。

(2)运行安装程序:

  • 双击下载的 .exe 文件启动安装向导。
  • 重要:在安装过程中,建议勾选 “Add Anaconda to my PATH environment variable”“Register Anaconda as my default Python” 选项。这将方便在命令行中使用 condapython 命令。
  • 按照安装向导的提示完成安装过程。

(3)验证安装:

  • 打开 Anaconda Prompt:
    • 点击 “开始” 菜单,搜索 “Anaconda Prompt” 并打开。

  • 检查 Anaconda 版本:
    • 在 Anaconda Prompt 中输入:
      conda --version
      
      显示类似 conda 4.x.x 的版本号。


2. 创建并激活虚拟环境

虚拟环境可以隔离不同项目的依赖,避免包版本冲突。接下来创建一个名为 tf_env 的虚拟环境,并在其中安装 TensorFlow 和 Keras。

步骤:

  1. 打开 Anaconda Prompt:

    • 点击 “开始” 菜单,搜索 “Anaconda Prompt” 并打开。
  2. 创建虚拟环境:

    • 在 Anaconda Prompt 中输入以下命令,并按回车:
      conda create -n tf_env python=3.10
      
    • 解释:
      • conda create:创建一个新的环境。
      • -n tf_env:环境名称为 tf_env
      • python=3.10:指定 Python 版本为 3.10。python=3.10:指定 Python 版本为 3.10。这里为什么指定版本为3.10,有什么依据-CSDN博客
  3. 确认创建:

    • 安装过程中会显示即将安装的包列表,输入 y 并按回车确认

  4. 激活虚拟环境:

    • 创建完成后,输入以下命令激活环境:
      conda activate tf_env
      
    • 提示:激活后,命令提示符前会显示 (tf_env),表示当前处于 tf_env 环境中。


3. 安装必要的库

在激活的虚拟环境中,安装 TensorFlow、Keras 以及其他必要的库(如 numpy、matplotlib)。

步骤:

  1. 更新 conda:

    • 虽然不一定必要,但建议更新 conda 以获取最新的包信息:
      conda update conda
      
    • 输入 y 并按回车确认更新。
  2. 安装 TensorFlow 和 Keras:

    • 输入以下命令并按回车:
      conda install tensorflow keras
      
    • 这里安装时间比较久,可以多等一会
    • 解释:
      • conda install:使用 conda 安装包。
      • tensorflow keras:安装 TensorFlow 和 Keras。
  3. 安装其他库:

    • 还可以安装 numpymatplotlib,用于数据处理和可视化:
      conda install numpy matplotlib
      
  4. 确认安装:

    • 安装过程中会显示将要安装的包列表,输入 y 并按回车确认。

4. 编写训练脚本

现在,开始编写一个 Python 脚本来构建、训练并评估我们的模型。我们将使用 MNIST 数据集作为示例。

步骤:

  1. 选择代码编辑器:

    可以使用任何文本编辑器(如 Notepad、Notepad++)或集成开发环境(如 VSCode、PyCharm)。推荐使用 VSCode,它功能强大且免费。

    • 下载 VSCode(如果尚未安装):
      • 访问 VSCode 官方下载页面。
      • 下载并安装适用于 Windows 的版本。
  2. 创建项目文件夹:

    • 在电脑中创建一个新的文件夹: C:\Users\FCZ\Desktop\Projects\mnist_project
  3. 创建 Python 脚本:

    • 打开 VSCode。
    • 点击 “文件” > “打开文件夹”,选择刚才创建的 mnist_project 文件夹。
    • 在左侧的文件资源管理器中,右键点击空白区域,选择 “新建文件”,命名为 train_mnist.py
  4. 编写训练代码:

    • train_mnist.py 中复制并粘贴以下代码:

      import tensorflow as tf
      import numpy as np
      import matplotlib.pyplot as pltdef main():# 1. 加载 MNIST 数据集(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()# 2. 数据预处理# 将灰度像素值(0~255)归一化到 0~1 之间x_train = x_train.astype("float32") / 255.0x_test  = x_test.astype("float32") / 255.0# 将图像张量拉伸为 (28*28) 向量x_train = x_train.reshape(-1, 28 * 28)x_test  = x_test.reshape(-1, 28 * 28)# 3. 构建模型model = tf.keras.models.Sequential([tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),tf.keras.layers.Dense(10, activation='softmax')])# 4. 编译模型model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 5. 训练模型history = model.fit(x_train, y_train,epochs=5,batch_size=64,validation_split=0.1  # 从训练数据中分出 10% 做验证)# 6. 模型评估test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)print(f"\n测试集上的准确率: {test_acc:.4f}")# 7. 保存模型model.save("mnist_model.h5")# 或者保存为 TensorFlow SavedModel 格式# model.save("mnist_saved_model", save_format="tf")# 8. 可视化训练过程plot_history(history)def plot_history(history):"""可视化训练曲线"""acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']epochs_range = range(len(acc))plt.figure(figsize=(12, 4))# 绘制准确率曲线plt.subplot(1, 2, 1)plt.plot(epochs_range, acc, label='训练准确率')plt.plot(epochs_range, val_acc, label='验证准确率')plt.legend(loc='lower right')plt.title('训练和验证准确率')# 绘制损失曲线plt.subplot(1, 2, 2)plt.plot(epochs_range, loss, label='训练损失')plt.plot(epochs_range, val_loss, label='验证损失')plt.legend(loc='upper right')plt.title('训练和验证损失')plt.show()if __name__ == "__main__":main()
      
    • 保存文件:Ctrl + S 保存 train_mnist.py


5. 运行训练脚本

现在,可以运行刚才编写的训练脚本,开始训练模型。

步骤:

  1. 打开 Anaconda Prompt 并激活虚拟环境:

    • 打开 Anaconda Prompt
    • 激活之前创建的虚拟环境 tf_env
      conda activate tf_env
      
    • 提示:命令提示符前会显示 (tf_env)
  2. 导航到项目目录:

    • 在 Anaconda Prompt 中,输入以下命令并按回车:
      cd C:\Users\FCZ\Desktop\Projects\mnist_project
      
  3. 运行训练脚本:

    • 输入以下命令并按回车:
      python train_mnist.py
      
    • 解释: 这将执行 train_mnist.py 脚本,开始训练模型。
  4. 观察训练过程:

    • 会看到类似以下的输出:
    • 说明:
      • 每个 epoch 显示训练损失、训练准确率、验证损失和验证准确率。
      • 训练完成后,会显示测试集上的准确率。
      • 如果安装了 matplotlib,还会弹出一个窗口显示训练和验证的准确率及损失曲线。

  脚本训练过程中报的OpenMP的错误信息: 

OMP: Error #15: Initializing libiomp5md.dll, but found libiomp5 already initialized.
OMP: Hint This means that multiple copies of the OpenMP runtime have been linked into the program. That is dangerous, since it can degrade performance or cause incorrect results. ...

解决办法:TensorFlow 和 Keras 进行模型训练时遇到了 OpenMP 的错误信息-CSDN博客

总结:

目前完成了模型的训练,接下来对模型结果分析以及优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/503833.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JavaEE初阶——计算机工作原理

一、什么是JavaEE JavaEE(Java Platform,Enterprise Edition)是sun公司(2009年4月20日甲骨文将其收购)推出的企业级应用程序版本。这个版本以前称为 J2EE。能够帮助我们开发和部署可移植、健壮、可伸缩且安全的服务器…

医学图像分析工具02:3D Slicer || 医学影像可视化与分析工具 支持第三方插件

3D Slicer 是一款功能全面的开源医学影像分析软件,广泛应用于影像处理、三维建模、影像配准和手术规划等领域。它支持多种医学影像格式(如 DICOM、NIfTI)和丰富的插件扩展,是神经科学、放射学和生物医学研究中不可或缺的工具。 在…

STM32裸机开发转FreeRTOS教程

目录 1. 简介2. RTOS设置(1)分配内存(2)查看任务剩余空间(3)使用osDelay 3. 队列的使用(1)创建队列(1)直接传值和指针传值(2)发送/接收…

使用高云小蜜蜂GW1N-2实现MIPI到LVDS(DVP)转换案例分享

作者:Hello,Panda 大家晚上好,熊猫君又来了。 今天要分享的是一个简单的MIPI到LVDS(DVP)接口转换的案例。目的就是要把低成本FPGA的应用潜力充分利用起来。 一、应用背景 这个案例的应用背景是:现在还在…

Express 加 sqlite3 写一个简单博客

例图: 搭建 命令: 前提已装好node.js 开始创建项目结构 npm init -y package.json:{"name": "ex01","version": "1.0.0","main": "index.js","scripts": {"test": &q…

git撤回提交、删除远端某版本、合并指定版本的更改

撤回提交 vscode的举例 一、只提交了还未推送的情况下 1.撤回最后一次提交,把最后一次提交的更改放到暂存区 git reset --soft HEAD~12.撤回最后一次提交,把最后一次提交的更改放到工作区 git reset --mixed HEAD~13.撤回最后一次提交,不…

[Linux]redis5.0.x升级至7.x完整操作流程

1. 从官网下载最新版redis: 官网地址:https://redis.io/download 注:下载需要的登录,如果选择使用github账号登录,那么需要提前在github账号中取消勾选“Keep my email addresses private”(隐藏我的邮箱…

xss-labs关卡记录15-20关

十五关 随便传一个参数,然后右击查看源码发现,这里有一个陌生的东西,就是ng-include。这里就是: ng-include指令就是文件包涵的意思,用来包涵外部的html文件,如果包涵的内容是地址,需要加引号。…

数据库回滚:大祸临头时

原文地址 什么是数据库回滚? 数据库技术中,回滚是通过撤销对数据库所做的一项或多项更改,将数据库返回到先前状态的操作。它是维护数据完整性和从错误中恢复的重要机制。 什么时候需要数据库回滚? 数据库回滚在以下几个场景中很…

年会抽奖Html

在这里插入图片描述 <!-- <video id"backgroundMusic" src"file:///D:/background.mp3" loop autoplay></video> --> <divstyle"width: 290px; height: 580px; margin-left: 20px; margin-top: 20px; background: url(D:/nianhu…

基于FPGA的出租车里程时间计费器

基于FPGA的出租车里程时间计费器 功能描述一、系统框图二、verilog代码里程增加模块时间增加模块计算价格模块上板视频演示 总结 功能描述 &#xff08;1&#xff09;&#xff1b;里程计费功能&#xff1a;3公里以内起步价8元&#xff0c;超过3公里后每公里2元&#xff0c;其中…

nginx-链路追踪(trace)实现

一. 需求场景&#xff1a; 在日常运维工作中&#xff0c;会经常遇到在有多重调用链的场景下&#xff0c;如请求遇到非致命error时&#xff0c;在各环节的定位会非常麻烦&#xff0c;举个例子&#xff1a;比如说&#xff0c;在一个有多重调用链的服务环境下&#xff0c;一个请求…

c#使用SevenZipSharp实现压缩文件和目录

封装了一个类&#xff0c;方便使用SevenZipSharp&#xff0c;支持加入进度显示事件。 双重加密压缩工具范例&#xff1a; using SevenZip; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; using System.Threading.…

MySQL和Hive中的行转列、列转行

水善利万物而不争&#xff0c;处众人之所恶&#xff0c;故几于道&#x1f4a6; 文章目录 MySQL1.行转列2.列转行 Hive1.行转列2.列转行(1)侧窗(2)union MySQL 1.行转列 把多行转成列。直接group&#xff0c;sum(if()) 2.列转行 Hive 1.行转列 select name,sum(if(kmshuxu…

快速上手:采用Let‘sEncrypt免费SSL证书配置网站Https (示例环境:Centos7.9+Nginx+Let‘sEncrypt)

1 关于Let’s Encrypt与Cerbot DNS验证 Let’s Encrypt 是一个提供 免费证书 的 认证机构。 Cerbot 是 Let’s Encrypt 提供的一个工具&#xff0c;用于自动化生成、验证和续订证书。 DNS验证是 Cerbot 支持的验证方式之一。相比 HTTP 验证或 TLS-ALPN 验证&#xff0c;DNS …

【Unity3D】Text文本文字掉落效果

相关技术&#xff1a;Text、TextMesh、Rigidbody&#xff08;刚体&#xff09;、BoxCollider&#xff08;碰撞体&#xff09;、TextGenerator、文本网格、文字网格 原理&#xff1a;使用UGUI Text获取其文字的每个字符网格坐标&#xff0c;转世界坐标生成对应的3D文本(TextMesh…

flutter 专题二十四 Flutter性能优化在携程酒店的实践

Flutter性能优化在携程酒店的实践 一 、前言 携程酒店业务使用Flutter技术开发的时间快接近两年&#xff0c;这期间有列表页、详情页、相册页等页面使用了Flutter技术栈进行了跨平台整合&#xff0c;大大提高了研发效率。在开发过程中&#xff0c;也遇到了一些性能相关问题和…

设计模式 行为型 命令模式(Command Pattern)与 常见技术框架应用 解析

命令模式&#xff08;Command Pattern&#xff09;是一种行为型设计模式&#xff0c;它旨在将请求发送者和接收者解耦&#xff0c;通过将一个请求封装为一个对象&#xff0c;从而允许参数化客户端对象以进行不同的请求、排队请求或记录请求&#xff0c;并支持可撤销操作。 在软…

NodeLocal DNS 全攻略:从原理到应用实践

文章目录 一、NodeLocal DNS是什么&#xff1f;二、为什么使用NodeLocal DNS&#xff1f;三、工作原理架构图四、安装NodeLocal DNS五、在应用中使用NodeLocal DNSCache六、验证 一、NodeLocal DNS是什么&#xff1f; NodeLocal DNSCache 通过在集群节点上运行一个 DaemonSet …

jenkins入门12-- 权限管理

Jenkins的权限管理 由于jenkins默认的权限管理体系不支持用户组或角色的配置&#xff0c;因此需要安装第三发插件来支持角色的配置&#xff0c;我们使用Role-based Authorization Strategy 插件 只有项目读权限 只有某个项目执行权限