Python 课程14-TensorFlow

前言

TensorFlow 是由 Google 开发的一个开源深度学习框架,广泛应用于机器学习和人工智能领域。它具有强大的计算能力,能够运行在 CPU、GPU 甚至 TPU 上,适用于从小型模型到大规模生产系统的各种应用场景。通过 TensorFlow,你可以构建和训练复杂的深度学习模型,如神经网络、卷积神经网络(CNN)、递归神经网络(RNN)等。

本教程将为你详细介绍如何使用 TensorFlow 构建和训练模型,包括其每个重要指令的说明与代码示例。无论你是初学者还是有经验的开发者,都可以通过本教程全面掌握 TensorFlow 的基础知识和应用。


目录

  1. TensorFlow 基础

    • 安装 TensorFlow
    • 创建张量(Tensors)
    • 张量操作与基本函数
    • 使用 GPU 加速
  2. 神经网络基础

    • 使用 Keras 构建简单的神经网络
    • 模型编译与训练
    • 模型评估与预测
    • 保存与加载模型
  3. 高级操作

    • 自定义训练循环
    • 使用 Dataset 进行数据管道处理
    • 自定义层与激活函数
    • 自动微分与梯度计算
  4. 卷积神经网络 (CNN)

    • 卷积层与池化层
    • 构建简单的 CNN 模型
    • 训练与评估 CNN 模型
  5. 递归神经网络 (RNN)

    • 构建简单的 RNN 模型
    • 使用 LSTM 处理序列数据
    • 训练与评估 RNN 模型

1. TensorFlow 基础

安装 TensorFlow

在开始使用 TensorFlow 之前,你需要安装它。可以通过 pip 安装 TensorFlow:

pip install tensorflow

 你可以使用以下代码检查是否成功安装:

import tensorflow as tf
print(tf.__version__)  # 输出 TensorFlow 的版本号
创建张量(Tensors)

张量(Tensors)是 TensorFlow 中的基本数据结构,类似于 NumPy 的数组,可以是多维数组。张量是不可变的,它们可以在 CPU 或 GPU 上运行。

  • 创建常量张量
    import tensorflow as tf# 创建一个常量张量
    tensor = tf.constant([[1, 2], [3, 4]])
    print(tensor)
    
  • 创建随机张量
    # 创建一个 3x3 的随机张量
    random_tensor = tf.random.normal([3, 3])
    print(random_tensor)
    
  • 创建全 0 或全 1 张量
    # 创建一个全 0 张量
    zero_tensor = tf.zeros([2, 3])# 创建一个全 1 张量
    one_tensor = tf.ones([2, 3])
    
    张量操作与基本函数

    TensorFlow 提供了丰富的张量操作函数,如加法、减法、矩阵乘法等。它们与 NumPy 的操作类似。

  • 加法与减法
    a = tf.constant([1, 2, 3])
    b = tf.constant([4, 5, 6])# 张量加法
    sum_tensor = tf.add(a, b)
    print(sum_tensor)# 张量减法
    sub_tensor = tf.subtract(a, b)
    print(sub_tensor)
    
  • 矩阵乘法
    matrix_a = tf.constant([[1, 2], [3, 4]])
    matrix_b = tf.constant([[5, 6], [7, 8]])# 矩阵乘法
    matmul_tensor = tf.matmul(matrix_a, matrix_b)
    print(matmul_tensor)
    
  • 重塑张量
    # 将张量重塑为 2x3
    reshaped_tensor = tf.reshape(tensor, [2, 3])
    print(reshaped_tensor)
    
    使用 GPU 加速

    如果你的计算机配有支持 CUDA 的 NVIDIA GPU,可以使用 GPU 加速深度学习任务。TensorFlow 会自动检测可用的 GPU 并将计算任务分配给它。

  • 查看是否使用 GPU
print("Is GPU available:", tf.config.list_physical_devices('GPU'))

如果 TensorFlow 检测到 GPU,则会在进行运算时优先使用 GPU 加速。


2. 神经网络基础

Keras 是 TensorFlow 中的高层 API,简化了神经网络的构建过程。使用 Keras,可以快速构建并训练神经网络模型。

使用 Keras 构建简单的神经网络

我们将通过 Keras 构建一个简单的前馈神经网络,并使用它对 MNIST 手写数字数据集进行分类。

  • 加载数据集
    from tensorflow.keras.datasets import mnist# 加载 MNIST 数据集
    (X_train, y_train), (X_test, y_test) = mnist.load_data()# 数据标准化:将像素值从 [0, 255] 缩放到 [0, 1]
    X_train = X_train / 255.0
    X_test = X_test / 255.0
    
  • 构建模型
    from tensorflow.keras import layers, models# 创建一个顺序模型
    model = models.Sequential()# 添加层
    model.add(layers.Flatten(input_shape=(28, 28)))  # 将 28x28 的图像展平为 784 维向量
    model.add(layers.Dense(128, activation='relu'))  # 添加全连接层
    model.add(layers.Dense(10, activation='softmax'))  # 输出层,10 个类别# 打印模型结构
    model.summary()
    
    模型编译与训练

    在编译模型时,你需要指定损失函数、优化器和评估指标。

  • 编译模型
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
    
  • 训练模型
# 训练模型
model.fit(X_train, y_train, epochs=5, batch_size=32)
模型评估与预测
  • 评估模型
    # 在测试集上评估模型
    test_loss, test_acc = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {test_acc:.2f}")
    
  • 使用模型进行预测
    # 使用训练好的模型进行预测
    predictions = model.predict(X_test)
    print(predictions[0])  # 输出第一个样本的预测结果
    
    保存与加载模型

    TensorFlow 支持将模型保存为文件,以便以后加载并继续使用。

  • 保存模型
    model.save('mnist_model.h5')
    
  • 加载模型
    from tensorflow.keras.models import load_model# 加载保存的模型
    loaded_model = load_model('mnist_model.h5')
    

    3. 高级操作

    自定义训练循环

    虽然 Keras 提供了简洁的 fit() 方法进行训练,但你也可以通过自定义训练循环更灵活地控制训练过程。

import tensorflow as tf# 自定义训练循环
for epoch in range(5):print(f"Epoch {epoch+1}")for batch, (X_batch, y_batch) in enumerate(train_dataset):with tf.GradientTape() as tape:# 前向传播logits = model(X_batch, training=True)loss_value = loss_fn(y_batch, logits)# 反向传播grads = tape.gradient(loss_value, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))print(f"Batch {batch}: Loss = {loss_value:.4f}")
使用 Dataset 进行数据管道处理

TensorFlow Dataset API 是一个用于加载和处理数据的高效工具,适合大规模数据集。

from tensorflow.data import Dataset# 创建 TensorFlow 数据集
dataset = Dataset.from_tensor_slices((X_train, y_train))# 数据集预处理
dataset = dataset.shuffle(buffer_size=1024).batch(32).repeat()# 迭代数据集
for batch_X, batch_y in dataset.take(1):print(batch_X.shape, batch_y.shape)
自定义层与激活函数

你可以使用 TensorFlow 创建自定义层或激活函数,以适应特殊需求。

  • 自定义层
    class MyCustomLayer(tf.keras.layers.Layer):def __init__(self, units=32):super(MyCustomLayer, self).__init__()self.units = unitsdef build(self, input_shape):self.w = self.add_weight(shape=(input_shape[-1], self.units),initializer='random_normal', trainable=True)self.b = self.add_weight(shape=(self.units,), initializer='zeros', trainable=True)def call(self, inputs):return tf.matmul(inputs, self.w) + self
    # 使用自定义层
    model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),MyCustomLayer(64),  # 自定义层tf.keras.layers.Dense(10, activation='softmax')
    ])model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 打印模型结构
    model.summary()
  • 自定义激活函数
    # 定义一个自定义激活函数
    def custom_activation(x):return tf.nn.relu(x) * tf.math.sigmoid(x)# 使用自定义激活函数
    model = tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28, 28)),tf.keras.layers.Dense(128),tf.keras.layers.Activation(custom_activation),  # 自定义激活函数tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    自动微分与梯度计算

    TensorFlow 提供了自动微分的功能,允许你通过 tf.GradientTape 计算梯度,并用于优化模型。

  • 使用 tf.GradientTape 计算梯度
    # 创建模型和优化器
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.Adam()# 定义损失函数
    def compute_loss(y_true, y_pred):return tf.reduce_mean(tf.losses.mean_squared_error(y_true, y_pred))# 前向传播和梯度计算
    x = tf.random.normal([3, 3])
    y = tf.random.normal([3, 10])with tf.GradientTape() as tape:y_pred = model(x)loss = compute_loss(y, y_pred)# 计算梯度
    gradients = tape.gradient(loss, model.trainable_variables)# 更新模型参数
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    


    4. 卷积神经网络 (CNN)

    卷积神经网络(CNN)特别适合处理图像数据。它通过卷积层提取图像的局部特征,能够在分类、物体检测等任务中取得优异的效果。

    卷积层与池化层
  • 卷积层:提取图像局部特征。
  • 池化层:减少特征图的尺寸,降低计算量,同时保留重要信息。
    from tensorflow.keras import layers# 构建 CNN 模型
    model = tf.keras.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),  # 池化层layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')
    ])# 打印模型结构
    model.summary()
    
    构建简单的 CNN 模型

    我们将使用 MNIST 数据集来训练一个简单的 CNN 模型。

  • 加载并预处理图像数据
    # 加载 MNIST 数据集
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()# 调整输入形状,增加通道维度(因为卷积层需要4D输入)
    X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255
    X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255
    
  • 编译并训练 CNN 模型
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
    model.fit(X_train, y_train, epochs=5, batch_size=64)
    
    训练与评估 CNN 模型
  • 评估模型
    # 在测试集上评估 CNN 模型
    test_loss, test_acc = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {test_acc:.2f}")
    

    5. 递归神经网络 (RNN)

    递归神经网络(RNN)适用于处理序列数据,如时间序列和自然语言处理任务。RNN 可以记忆之前的信息,并用于后续的输入,从而在处理时间依赖关系时非常有效。

    构建简单的 RNN 模型

    我们可以使用 LSTM(长短期记忆网络) 这一特殊类型的 RNN,来处理序列数据,如股票价格预测或文本分类。

  • 使用 LSTM 构建 RNN 模型
    from tensorflow.keras import layers# 构建 RNN 模型
    model = tf.keras.Sequential([layers.LSTM(128, input_shape=(10, 1)),  # 假设输入序列长度为10layers.Dense(10, activation='softmax')  # 输出层,10 个类别
    ])# 打印模型结构
    model.summary()
    
    使用 LSTM 处理序列数据

    我们将模拟序列数据,并使用 LSTM 进行预测。

  • 生成序列数据
    import numpy as np# 创建模拟的时间序列数据
    X_train = np.random.rand(1000, 10, 1)  # 1000 个样本,每个样本长度为 10
    y_train = np.random.randint(0, 10, 1000)  # 10 个类别的随机标签
    
  • 编译并训练模型
    model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
    model.fit(X_train, y_train, epochs=5, batch_size=32)
    
    训练与评估 RNN 模型
  • 评估模型
    # 假设我们有测试集 X_test 和 y_test
    test_loss, test_acc = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {test_acc:.2f}")
    

    结论

    通过本详细的 TensorFlow 教程,你已经掌握了构建和训练深度学习模型的基本知识。我们从最基础的张量操作开始,逐步构建神经网络,并使用 Keras 高层 API 来快速实现深度学习模型。你还学习了如何通过自定义训练循环和层实现更高级的功能,以及使用 CNNRNN 处理图像和序列数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/426156.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Unity+LeapMotion2的使用

开始吧 导入步骤1.到官网下载软件并安装2.安装插件3.场景中添加检测管理器4.场景中添加手部模型 更多细节 导入步骤 1.到官网下载软件并安装 地址 重启电脑后连接设备 可以看到连接成功 2.安装插件 (也可以看官方教程) Project—>PackageManag…

从AI应用排行榜选择AI产品(9月)

2024年9月13日,OpenAI公司宣布推出其全新的AI模型:o1,在数学、编程和科学问题的解决处理能力上取得了显著进步。该模型通过自我对弈强化学习(Self-play RL)和思维链(Chain of Thought, CoT)技术…

openssl的使用

1、编译 Github下载:https://github.com/openssl/openssl 官网下载:https://openssl-library.org/source/index.html 官网历史版本:https://www.openssl.org/source/old/ 1.1 Windows下编译 我的文章:OPC UA使用 Openssl库编译…

0基础带你入门Linux之使用

1.Ubuntu软件管理 回顾一下,我们之前使用su root切换到root模式,使用who 发现为什么显示的还是bd用户呢?为什么呢? 这个who是主要来查看的是我们登录的时候是以什么用户登录的 所以即使我们使用who进行查看的时候显示的还是bd用…

【浅水模型MATLAB】尝试复刻SCI论文中的溃坝流算例

【浅水模型MATLAB】尝试复刻SCI论文中的溃坝流算例 前言问题描述控制方程及数值方法浅水方程及其数值计算方法边界条件的实现 代码框架与关键代码模拟结果 更新于2024年9月17日 前言 这篇博客算是学习浅水方程,并利用MATLAB复刻Liang (2004)1中溃坝流算例的一个记录…

SysML图例-重症输液泵

SysML图中词汇 infusion pump 输液泵。替代传统的重力式吊瓶输液,达到更加精准和更加安全给药的目的。 syringe pump 注射泵。也称作微量输液泵,主要目的是对容量式输液泵在微量给药方面的一个补充。

ECMAScript与JavaScript的区别

目录 一、什么是ECMAScript? 二、什么是JavaScript? 三、ECMAScript与JavaScript的关系 3.1 ECMAScript规范版本 3.2 JavaScript的实现 四、ECMAScript与JavaScript的主要区别 4.1 规范与实现的区别 4.2 版本更新 4.3 环境支持 4.4 语言特性 五…

中秋期间互联网产品故障事件(晋江、115盘、阿里云盘)盘点

24年中秋期间,除了肆掠的“贝碧嘉”台风外,互联网故障bug事件也不少,趁着有空盘点下,可作为员工信息安全培训案例。 一:晋江文学城访问异常(基础环境故障类) 9月14日,“晋江崩了”冲…

【python设计模式3】创建型模式2

目录 抽象工厂模式 建造者模式 单例模式 创建型模式概述 抽象工厂模式 抽象工厂模式:定义一个工厂类的接口让工厂子类来创建一系列相关或者相互依赖的对象。相比工厂方法模式,抽象工厂模式中的每一个具体工厂都生产一套产品。下面是生产厂商生产一部手…

VSCode扩展连接虚拟机MySQL数据库

在虚拟机安装MySQL vscode通过ssh远程登录Ubuntu 在vscode终端运行以下命令。 sudo apt-get install mysql-server-5.7 用以下命令确认MySQL是否安装完成。 sudo mysql MySQL安装成功。 在VSCode安装SQL扩展 扩展名:MySQL Shell for VS Code。 安装完成后&am…

【2025】智慧居家养老服务平台的设计与实现、基于AI的居家养老服务平台、居家养老服务平台开发、智慧养老服务平台设计

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

55.【C语言】字符函数和字符串函数(strstr函数)

11.strstr函数 *简单使用 strstr: string string cplusplus的介绍 点我跳转 翻译: 函数 strstr const char * strstr ( const char * str1, const char * str2 ); 或另一个版本char * strstr ( char * str1, const char * str2 ); 寻找子字符串 返回指向第一次出现在字…

从零开始学PostgreSQL (十四):高级功能

目录 1. 简介 2. 视图 3. 外键 4. 事务 5. 窗口函数 6. 继承 7. 结论 简介 PostgreSQL是一个强大且开源的关系型数据库管理系统,以其稳定性、功能丰富性和对SQL标准的广泛支持而闻名。它不仅提供了传统的关系型数据库功能,如事务处理、外键约束和视图&am…

CISP-PTE CMS sqlgun靶场

sql靶场有个搜索框先点一下go,有回显说明存在漏洞 有个xss 然后在这里尝试sql注入 输入 -1 union select 1,2,3# 有回显可以查看数据库 然后查询数据库,用户 查询数据库的表名 查询它的数据这里admin用户的密码是md5加密 去解密看看 然后扫描ip目录发…

Zookeeper 3.8.4 安装和参数解析

安装 zookeeper 之前必须先安装 JDK,有关Linux环境JDK可以参考我以前写的博文 1、关于Linux服务器配置java环境遇到的问题 2、Linux环境安装openJDK 3、Centos7.3云服务器上安装Nginx、MySQL、JDK、Tomcat环境 文章目录 1. zookeeper 安装2. 参数解析 1. zookeeper …

03-Mac系统PyCharm主题设置

目录 1. 打开PyCharm窗口 2. Mac左上角点击PyCharm,点击Settings 3. 点击第一项Appearance& Behavior 4. 点击Appearance 5. 找到Theme进行设置 1. 打开PyCharm窗口 2. Mac左上角点击PyCharm,点击Settings 3. 点击第一项Appearance& Behavi…

物理感知扩散的 3D 分子生成模型 - PIDiff 评测

PIDiff 是一个针对蛋白质口袋特异性的、物理感知扩散的 3D 分子生成模型,通过考虑蛋白质-配体结合的物理化学原理来生成分子,在原理上,生成的分子可以实现蛋白-小分子的自由能最小。 一、背景介绍 PIDiff 来源于延世大学计算机科学系的 Sang…

Git 原理(提交对象)(结合图与案例)

Git 原理(提交对象) 这一块主要讲述下 Git 的原理。 在进行提交操作时,Git 会保存一个提交对象(commit object): 该提交对象会包含一个指向暂存内容快照的指针; 该提交对象还包含了作者的姓…

Java | Leetcode Java题解之第403题青蛙过河

题目&#xff1a; 题解&#xff1a; class Solution {public boolean canCross(int[] stones) {int n stones.length;boolean[][] dp new boolean[n][n];dp[0][0] true;for (int i 1; i < n; i) {if (stones[i] - stones[i - 1] > i) {return false;}}for (int i 1…

HAL库学习梳理——UART

笔者跟着B站铁头山羊视频学习 STM32-HAL库 开发教程。下面对HAL库有关UART课程知识和应用做一个梳理。 省流&#xff1a; uint8_t byteNumber 0x5a;uint8_t byteArray[] {0,1,2,3,4,5};char ch a;char *str "Hello word";HAL_UART_Transmit(&huart1,&by…