浅析Estimator、model_fn与EstimatorSpec

参考阅读:https://zhuanlan.zhihu.com/p/74857888

文章目录

  • 综合对比
      • Estimator
      • model_fn
      • EstimatorSpec
      • 关系
      • 总结
  • Estimator
      • 主要功能
      • 构造函数参数
      • 示例用法
      • 小结
  • model_fn
  • EstimatorSpec
      • 字段解释
      • 解释代码
      • 用途

综合对比

Estimatormodel_fnEstimatorSpec 是 TensorFlow 中用于构建、训练和评估模型的三个核心组件。它们之间的关系可以总结如下:

Estimator

  • 定义: Estimator 是 TensorFlow 提供的高层 API,用于简化和标准化模型的训练、评估和预测。
  • 功能:
    • 封装训练、评估和预测的逻辑。
    • 管理检查点、日志记录和模型保存。
    • 提供一致的接口来处理不同类型的模型。
  • 参数:
    • model_fn: 定义模型的函数。
    • model_dir: 模型保存目录。
    • config: 执行环境的配置信息。
    • params: 超参数字典。
    • warm_start_from: 热启动配置。

model_fn

  • 定义: model_fn 是一个函数,定义了模型的结构和行为。它由 Estimator 在训练、评估和预测时调用。
  • 功能:
    • 构建模型的计算图。
    • 根据运行模式(TRAIN、EVAL、PREDICT)返回不同的操作。
    • 接受特征、标签、模式、超参数和配置信息作为输入。
  • 返回值:
    • 返回一个 EstimatorSpec 对象,定义了模型在不同模式下的行为。

EstimatorSpec

  • 定义: EstimatorSpec 是一个对象,包含了模型在训练、评估和预测模式下的所有必要信息。
  • 功能:
    • 定义模型的预测、损失、训练操作和评估指标。
    • 提供一致的接口,使 Estimator 能够在不同模式下正确运行模型。
  • 字段:
    • mode: 运行模式(TRAIN、EVAL、PREDICT)。
    • predictions: 预测结果。
    • loss: 损失值。
    • train_op: 训练操作。
    • eval_metric_ops: 评估指标操作。
    • export_outputs: 导出输出。
    • training_chief_hooks, training_hooks, scaffold, evaluation_hooks, prediction_hooks: 各种钩子和脚手架对象,用于在不同阶段执行自定义操作。

关系

  1. Estimator 使用 model_fn:

    • Estimator 调用 model_fn 来构建模型的计算图并定义其行为。
    • model_fn 接受特征、标签、模式、超参数和配置信息,并返回一个 EstimatorSpec 对象。
  2. model_fn 返回 EstimatorSpec:

    • model_fn 根据当前的运行模式(TRAIN、EVAL、PREDICT)创建并返回一个 EstimatorSpec 对象。
    • EstimatorSpec 对象包含了模型在当前模式下所需的所有操作和输出。
  3. Estimator 使用 EstimatorSpec:

    • Estimator 使用 EstimatorSpec 中定义的操作来执行训练、评估和预测。
    • 根据 EstimatorSpec 中的信息,Estimator 知道如何处理模型的预测、损失计算和训练步骤。

总结

  • Estimator 是高层接口,用于管理和运行模型。
  • model_fn 是用户定义的函数,用于构建模型的计算图并返回 EstimatorSpec
  • EstimatorSpec 定义了模型在不同模式下的行为,由 model_fn 返回,并由 Estimator 使用。

Estimator

Estimator 是 TensorFlow 提供的一个高层 API,用于简化模型的训练和评估。它封装了一个模型,模型通过 model_fn 指定。Estimator 负责处理训练、评估和预测所需的所有操作,并将结果输出到指定的目录。

主要功能

  1. 模型训练、评估和预测: Estimator 封装了这些操作,简化了模型的开发和部署过程。
  2. 模型保存和恢复: 所有输出(如检查点、事件文件等)都写入 model_dir,或其子目录。这样可以方便地保存和恢复模型。
  3. 运行配置: 通过 config 参数,Estimator 可以获取有关执行环境的信息,并将其传递给 model_fn
  4. 超参数传递: 通过 params 参数,Estimator 可以将超参数传递给 model_fn 和输入函数。

构造函数参数

  • model_fn: 模型函数,定义了如何构建模型。它接受以下参数:

    • features: 从 input_fn 返回的特征,通常是 TensorTensor 字典。
    • labels: 从 input_fn 返回的标签,通常是 TensorTensor 字典。在预测模式下,labelsNone
    • mode: 运行模式,可以是 TRAINEVALPREDICT
    • params: 超参数字典,包含传递给 Estimator 的超参数。
    • config: RunConfig 对象,包含执行环境的配置信息。
  • model_dir: 模型参数、图等的保存目录,也可以用于从目录加载检查点以继续训练之前保存的模型。

  • config: RunConfig 配置对象,包含执行环境的配置信息。如果model_fn函数也定义config这个变量,则会将config传给model_fn。

  • params: 超参数字典,包含传递给 model_fn 的超参数。

  • warm_start_from: 检查点或 SavedModel 的文件路径,用于热启动,或一个 WarmStartSettings 对象以完全配置热启动。

示例用法

  1. 创建一个 Estimator 实例

    estimator = tf.estimator.DNNClassifier(feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],hidden_units=[1024, 512, 256],warm_start_from="/path/to/checkpoint/dir"
    )
    
  2. 定义 model_fn

    def my_model_fn(features, labels, mode, params):# 构建模型logits = build_model(features, mode, params)predictions = {'classes': tf.argmax(input=logits, axis=1),'probabilities': tf.nn.softmax(logits)}# PREDICT 模式if mode == tf.estimator.ModeKeys.PREDICT:return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)# 计算损失loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)# 训练操作if mode == tf.estimator.ModeKeys.TRAIN:optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)# 评估指标eval_metric_ops = {'accuracy': tf.metrics.accuracy(labels=labels, predictions=predictions['classes'])}return tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
    
  3. 使用 Estimator 进行训练、评估和预测

    # 训练
    estimator.train(input_fn=train_input_fn, steps=1000)# 评估
    eval_result = estimator.evaluate(input_fn=eval_input_fn)
    print(eval_result)# 预测
    predictions = estimator.predict(input_fn=predict_input_fn)
    for pred in predictions:print(pred)
    

小结

Estimator 提供了一种结构化的方法来定义和管理 TensorFlow 模型,使得模型的训练、评估和预测更加方便和标准化。它通过 model_fn 将模型的构建与训练、评估和预测逻辑分离,并且通过配置和参数化提供了灵活性。

model_fn

输入:

  • features: 从 input_fn 返回的特征,通常是 TensorTensor 字典。
  • labels: 从 input_fn 返回的标签,通常是 TensorTensor 字典。在预测模式下,labelsNone
  • mode: 运行模式,可以是 TRAINEVALPREDICT
  • params: 超参数字典,包含传递给 Estimator 的超参数。
  • config: RunConfig 对象,包含执行环境的配置信息。

返回值:
一个EstimatorSpec

前两个参数是从输入函数中返回的特征和标签批次;也就是说,features 和 labels 是模型将使用的数据。

params 是一个字典,它可以传入许多参数用来构建网络或者定义训练方式等。例如通过设置params[‘n_classes’]来定义最终输出节点的个数等。
config 通常用来控制checkpoint或者分布式什么,这里不深入研究。
mode 参数表示调用程序是请求训练、评估还是预测,分别通过tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT 来定义。另外通过观察DNNClassifier的源代码可以看到,mode这个参数并不用手动传入,因为Estimator会自动调整。例如当你调用estimator.train(…)的时候,mode则会被赋值tf.estimator.ModeKeys.TRAIN。

模型有训练,验证和测试三种阶段,而且对于不同模式,对数据有不同的处理方式。例如在训练阶段,我们需要将数据喂给模型,模型基于输入数据给出预测值,然后我们在通过预测值和真实值计算出loss,最后用loss更新网络参数,而在评估阶段,我们则不需要反向传播更新网络参数,换句话说,model_fn需要对三种模式设置三套代码

EstimatorSpec

collections.namedtuple 是 Python 标准库中的一个函数,用于创建不可变的、具名的元组(named tuple)。这些具名元组可以像类一样使用,有字段名称,使代码更具可读性和可维护性。

在这段代码中,collections.namedtuple 被用来创建一个名为 EstimatorSpec 的具名元组,它包含了一组用于定义模型在不同模式下行为的字段。以下是每个字段的解释:

字段解释

  1. mode: 模式,表示当前的运行模式,可以是训练(TRAIN)、评估(EVAL)或预测(PREDICT)模式。
  2. predictions: 预测值,可以是一个 TensorTensor 字典,用于预测模式下输出结果。
  3. loss: 损失值,一个标量 Tensor,表示模型的损失,用于训练和评估模式。
  4. train_op: 训练操作,表示在训练模式下执行的操作(通常是优化步骤)。
  5. eval_metric_ops: 评估指标操作,是一个字典,包含评估模式下的度量结果。
  6. export_outputs: 导出输出,是一个字典,定义了模型在导出为 SavedModel 时的输出签名。
  7. training_chief_hooks: 主训练钩子,是一个迭代器,包含在主 worker 上运行的 SessionRunHook 对象。
  8. training_hooks: 训练钩子,是一个迭代器,包含在所有 worker 上运行的 SessionRunHook 对象。
  9. scaffold: 脚手架,是一个 tf.train.Scaffold 对象,用于设置初始化、保存和恢复操作。
  10. evaluation_hooks: 评估钩子,是一个迭代器,包含在评估过程中运行的 SessionRunHook 对象。
  11. prediction_hooks: 预测钩子,是一个迭代器,包含在预测过程中运行的 SessionRunHook 对象。

解释代码

collections.namedtuple('EstimatorSpec', ['mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops','export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold','evaluation_hooks', 'prediction_hooks'
])

这行代码创建了一个名为 EstimatorSpec 的具名元组类,它包含了上述的这些字段。EstimatorSpec 类可以用于存储和传递这些字段的值,使得在模型函数(model_fn)中可以方便地定义和返回这些值。

用途

EstimatorSpec 主要用于 TensorFlow 的 Estimator API 中,以统一的方式定义模型的各个组成部分。通过使用 EstimatorSpec,可以确保模型在不同模式下的行为是一致且正确的。例如:

  • 在训练模式下,必须提供 losstrain_op
  • 在评估模式下,必须提供 loss
  • 在预测模式下,必须提供 predictions

使用 EstimatorSpec,可以更简洁和清晰地定义模型的各个部分,并且通过具名元组的方式,使代码更加可读和易于维护。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/367064.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

207.贪心算法:最大子数组和(力扣)

代码展示 class Solution { public:int maxSubArray(vector<int>& nums) {int result INT_MIN; // 初始化结果为最小可能的整数int sum 0; // 初始化当前子数组和为0// 遍历数组中的每一个元素for (int i 0; i < nums.size(); i){sum nums[i]; //…

昇思25天学习打卡营第9天|MindSpore-Vision Transformer图像分类

Vision Transformer图像分类 Vision Transformer(ViT)简介 近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前…

LinkedList底层原理

LinkedList特有方法 源码分析

使用工业自动化的功能块实现大语言模型应用

大语言模型无所不能&#xff1f; 以chatGPT为代表的大语言模型横空出世&#xff0c;在世界范围内掀起了一场AI革命。给人的感觉似乎大模型语言无所不能。它不仅能够生成文章&#xff0c;图片和视频&#xff0c;能够翻译文章&#xff0c;分析科学和医疗数据&#xff0c;甚至可以…

前端git约定式规范化提交-commitizen

当使用commitizen进行代码提交时&#xff0c;commitizen会提示你在提交代码时填写所必填的提交字段信息内容。 1、全局安装commitizen npm install -g commitizen4.2.4 2、安装并配置 cz-customizeable 插件 2.1 使用 npm 下载 cz-customizeable npm i cz-customizeable6.…

低代码组件扩展方案在复杂业务场景下的设计与实践

组件是爱速搭的前端页面可视化模块的核心能力之一&#xff0c;它将前端研发人员从无休止的页面样式微调和分辨率兼容工作中解放了出来。 目前&#xff0c;爱速搭通过内置的上百种功能组件&#xff08;120&#xff09;&#xff0c;基本可以覆盖大部分中后台页面的可视化设计场景…

软件鉴定测试的工作内容是什么?专业软件鉴定测试报告获取指南

软件鉴定测试是指对软件产品进行全面的检测和评估&#xff0c;以验证其是否符合规定的标准和要求。通过测试&#xff0c;能够发现软件中存在的问题和缺陷&#xff0c;并提供相应的改进建议。在不同的测试阶段&#xff0c;使用不同的测试方法和工具&#xff0c;包括功能测试、性…

数据分析如何在企业中发挥价值

数据分析如何在企业中发挥价值 数据分析的目的是什么为什么怎么做做什么 思考问题流程确认问题拆解问题量化分析 分析数据流程收集数据处理数据制作图表 全流程 数据分析的目的 是什么 通过数据量化企业当前的经营现状或业务事实&#xff0c;将业务细节转换为具体数据&#xf…

爬虫cookie是什么意思

“爬虫 cookie”指的是网络爬虫在访问网站时所使用的cookie&#xff0c;网络爬虫是一种自动化程序&#xff0c;用于在互联网上收集信息并进行索引&#xff0c;这些信息可以用于搜索引擎、数据分析或其他目的。 本教程操作系统&#xff1a;Windows10系统、Dell G3电脑。 “爬虫…

数据库取出来的日期格式是数组格式,序列化日期格式

序列化前&#xff0c;如图所示&#xff1a; 解决方式&#xff0c;序列化日期&#xff08;localdatetime&#xff09;格式 步骤一、添加序列化类 package com.abliner.test.common.configure;import com.alibaba.fastjson.serializer.JSONSerializer; import com.alibaba.fas…

Python编写简单爬虫

文章目录 Python编写简单爬虫安装必要的库编写爬虫代码解析和存储数据注意事项 Python编写简单爬虫 安装必要的库 在开始编写爬虫之前&#xff0c;你需要安装一些必要的库。我们将使用requests库来发送HTTP请求&#xff0c;使用BeautifulSoup库来解析HTML内容。你可以使用以下…

fiddler抓https包

1&#xff0c;安装fiddler省略 2&#xff0c;下载证书步骤&#xff1a;tools-options-https 点击确认&#xff0c;点击OK&#xff0c;点击是 把证书安装到谷歌浏览器上步骤&#xff1a;点击谷歌浏览器右上角的设置&#xff0c;在搜索框中搜索证书&#xff0c;点击“证书管理”…

win10下Python的安装和卸载

前言 之前电脑上安装了python3.9版本&#xff0c;因为工作需要使用3.6版本的Python&#xff0c;需要将3.9版本卸载&#xff0c;重新安装3.6版本。下面就是具体的操作步骤: 1. 卸载 在我的电脑中搜索到3.9版本的安装文件&#xff0c;如下图&#xff1a; 双击该应用程序&#xf…

DevOps认证是什么?DevOps工具介绍

DevOps 这个词是由Development&#xff08;开发&#xff09; 和 Operations&#xff08;运维&#xff09;组合起来的&#xff0c;你可以把它理解成为一种让开发团队和运维团队紧密合作的方法。 DevOps从2009年诞生到现在已经14年多了&#xff0c;一开始大家还在摸索&#xff0…

马斯克宣布xAI将在8月份推出Grok-2大模型 预计年底推出Grok-3

在今年内&#xff0c;由特斯拉创始人马斯克创立的人工智能初创公司xAI将推出两款重要产品Grok-2和Grok-3。马斯克在社交平台上透露了这一消息&#xff0c;其中Grok-2预计在今年8月份面世&#xff0c;而Grok-3则计划于年底前亮相。 除此之外&#xff0c;马斯克还表示&#xff0c…

WLAN的WPA3安全技术

Wi-Fi安全加密的演进下图所示&#xff0c;当前最新的加密方式是WPA3。WPA3对现有网络提供了全方位的安全防护&#xff0c;增强了公共网络、家庭网络和802.1X企业网的安全性。 WPA3的核心为对等实体同时验证方式(Simultaneous Authentication of Equals, SAE)&#xff0c;即通信…

Android AlertDialog对话框

目录 AlertDialog对话框普通对话框单选框多选框自定义框 AlertDialog对话框 部分节选自博主编《Android应用开发项目式教程》&#xff08;机械工业出版社&#xff09;2024.6 在Android中&#xff0c;AlertDialog弹出对话框用于显示一些重要信息或者需要用户交互的内容。 弹出…

双目摄像头测距

Opencv双目校正函数 stereoRectify 详解 参数说明&#xff1a; 输入参数&#xff1a; cameraMatrix1&#xff1a;左目相机内参矩阵 distCoeffs1&#xff1a;左目相机畸变参数 cameraMatrix2&#xff1a;右目相机内参矩阵 distCoeffs2&#xff1a;右目相机畸变参数 imageSize&…

使用 ADB 查看 Android 设备的 CPU 使用率(详解)

在 Android 开发和调试过程中&#xff0c;监控设备的性能数据至关重要。CPU 使用率是一个关键的性能指标&#xff0c;它能够帮助开发者识别应用的性能瓶颈和优化机会。本文将详细介绍如何使用 Android Debug Bridge (ADB) 查看设备的 CPU 使用率&#xff0c;并解释终端上各个参…

LLM指令微调Prompt的最佳实践(二):Prompt迭代优化

文章目录 1. 前言2. Prompt定义3. 迭代优化——以产品说明书举例3.1 产品说明书3.2 初始Prompt3.3 优化1: 添加长度限制3.4 优化2: 细节纠错3.5 优化3: 添加表格 4. 总结5. 参考 1. 前言 前情提要&#xff1a; 《LLM指令微调Prompt的最佳实践&#xff08;一&#xff09;&#…