tf.Keras (tf-1.15)使用记录3-model.compile方法

model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。

注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()

1 基本用法

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

1) optimizer: 优化器

可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。

以下是可以使用的字符串参数:

  1. 'sgd': 随机梯度下降优化器
  2. 'adam': Adam 优化器
  3. 'rmsprop': RMSprop 优化器
  4. 'adagrad': Adagrad 优化器
  5. 'adadelta': Adadelta 优化器
  6. 'adamax': Adamax 优化器
  7. 'nadam': Nadam 优化器
  8. 'ftrl': Ftrl 优化器

需要注意的是:

  1. 这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。

  2. 使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
    
  3. ‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。

  4. 在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。

2) loss: 损失函数

用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。

以下是一些常用的字符串参数对应的损失函数:

  1. 'binary_crossentropy': 用于二分类问题的交叉熵损失。
  2. 'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。
  3. 'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。
  4. 'mean_squared_error''mse': 均方误差损失,用于回归问题。
  5. 'mean_absolute_error''mae': 平均绝对误差损失,用于回归问题。
  6. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  7. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题,对小差异不敏感。
  8. 'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。
  9. 'kullback_leibler_divergence''kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。
  10. 'hinge': 用于“最大间隔”分类问题的铰链损失。
  11. 'squared_hinge': 铰链损失的平方版本。
  12. 'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。

3) metrics: 评估指标列表,用于评估模型的性能

这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy''precision''recall' 等。

model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:

  1. 'accuracy''acc': 准确率,用于分类问题。
  2. 'binary_accuracy': 二分类准确率。
  3. 'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。
  4. 'sparse_categorical_accuracy': 多分类准确率,标签为整数。
  5. 'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。
  6. 'sparse_top_k_categorical_accuracy': 与 'top_k_categorical_accuracy' 类似,但适用于标签为整数的情况。
  7. 'mean_squared_error''mse': 均方误差,用于回归问题。
  8. 'mean_absolute_error''mae': 平均绝对误差,用于回归问题。
  9. 'mean_absolute_percentage_error''mape': 平均绝对百分比误差,用于回归问题。
  10. 'mean_squared_logarithmic_error''msle': 均方对数误差,用于回归问题。
  11. 'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。
  12. 'precision': 精确率,用于二分类或多标签分类问题。
  13. 'recall': 召回率,用于二分类或多标签分类问题。
  14. 'auc': 曲线下面积(Area Under the Curve),用于二分类问题。

使用示例:

# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])

对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。

from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])

2 高级用法

  • 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
  • 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
  • 使用多个损失函数和评估指标:

如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。

model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
  • 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.rhkb.cn/news/10974.html

如若内容造成侵权/违法违规/事实不符,请联系长河编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

创建前端项目的方法

目录 一、创建前端项目的方法 1.前提:安装Vue CLI 2.方式一:vue create项目名称 3.方式二:vue ui 二、Vue项目结构 三、修改Vue项目端口号的方法 一、创建前端项目的方法 1.前提:安装Vue CLI npm i vue/cli -g 2.方式一&…

(leetcode 213 打家劫舍ii)

代码随想录&#xff1a; 将一个线性数组换成两个线性数组&#xff08;去掉头&#xff0c;去掉尾&#xff09; 分别求两个线性数组的最大值 最后求这两个数组的最大值 代码随想录视频 #include<iostream> #include<vector> #include<algorithm> //nums:2,…

【Python】第七弹---Python基础进阶:深入字典操作与文件处理技巧

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、字典 1.1、字典是什么 1.2、创建字典 1.3、查找 key 1.4、新增/修改元素 1.5、删除元素 1.6、遍历…

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操 Janus-Pro-7B介绍 Janus-Pro-7B 是由 DeepSeek 开发的多模态 AI 模型&#xff0c;它在理解和生成方面取得了显著的进步。这意味着它不仅可以处理文本&#xff0c;还可以处理图像等其他模态的信息。 模型主要特点:Permalink…

BW AO/工作簿权限配置

场景&#xff1a; 按事业部配置工作簿权限&#xff1b; 1、创建用户 事务码&#xff1a;SU01&#xff0c;用户主数据的维护&#xff0c;可以创建、修改、删除、锁定、解锁、修改密码等 用户设置详情页 2、创建权限角色 用户的权限菜单是通过权限角色分配来实现的 2.1、自定…

jstat命令详解

jstat 用于监视虚拟机运行时状态信息的命令&#xff0c;它可以显示出虚拟机进程中的类装载、内存、垃圾收集、JIT 编译等运行数据。 命令的使用格式如下。 jstat [option] LVMID [interval] [count]各个参数详解&#xff1a; option&#xff1a;操作参数LVMID&#xff1a;本…

3.Spring-事务

一、隔离级别&#xff1a; 脏读&#xff1a; 一个事务访问到另外一个事务未提交的数据。 不可重复读&#xff1a; 事务内多次查询相同条件返回的结果不同。 幻读&#xff1a; 一个事务在前后两次查询同一个范围的时候&#xff0c;后一次查询看到了前一次查询没有看到的行。 二…

MYSQL--一条SQL执行的流程,分析MYSQL的架构

文章目录 第一步建立连接第二部解析 SQL第三步执行 sql预处理优化阶段执行阶段索引下推 执行一条select 语句中间会发生什么&#xff1f; 这个是对 mysql 架构的深入理解。 select * from product where id 1;对于mysql的架构分层: mysql 架构分成了 Server 层和存储引擎层&a…

ReentrantReadWriteLock源码分析

文章目录 概述一、状态位设计二、读锁三、锁降级机制四、写锁总结 概述 ReentrantReadWriteLock&#xff08;读写锁&#xff09;是对于ReentranLock&#xff08;可重入锁&#xff09;的一种改进&#xff0c;在可重入锁的基础上&#xff0c;进行了读写分离。适用于读多写少的场景…

51单片机开发:温度传感器

温度传感器DS18B20&#xff1a; 初始化时序图如下图所示&#xff1a; u8 ds18b20_init(void){ds18b20_reset();return ds18b20_check(); }void ds18b20_reset(void){DS18B20_PORT 0;delay_10us(75);DS18B20_PORT 1;delay_10us(2); }u8 ds18b20_check(void){u8 time_temp0;wh…

vue2项目(一)

项目介绍 电商前台项目 技术架构&#xff1a;vuewebpackvuexvue-routeraxiosless.. 封装通用组件登录注册token购物车支付项目性能优化 一、项目初始化 使用vue create projrct_vue2在命令行窗口创建项目 1.1、脚手架目录介绍 ├── node_modules:放置项目的依赖 ├──…

labelme_json_to_dataset ValueError: path is on mount ‘D:‘,start on C

这是你的labelme运行时label照片的盘和保存目的地址的盘不同都值得报错 labelme_json_to_dataset ValueError: path is on mount D:,start on C 只需要放一个盘但可以不放一个目录

物联网 STM32【源代码形式-使用以太网】连接OneNet IOT从云产品开发到底层MQTT实现,APP控制 【保姆级零基础搭建】

物联网&#xff08;IoT&#xff09;‌是指通过各种信息传感器、射频识别技术、全球定位系统、红外感应器等装置与技术&#xff0c;实时采集并连接任何需要监控、连接、互动的物体或过程&#xff0c;实现对物品和过程的智能化感知、识别和管理。物联网的核心功能包括数据采集与监…

无心剑七绝《深度求索》

七绝深度求索 深研妙理定乾坤 度世玄机启智门 求路千难兼万险 索萦华夏自为尊 2025年2月1日 平水韵十三元平韵 无心剑七绝《深度求索》以平水韵十三元平韵写成&#xff0c;意境深远&#xff0c;气势磅礴。诗中“深研妙理定乾坤”开篇点题&#xff0c;展现出对深奥道理的钻研与探…

Hot100之普通数组

53最大子数组和 题目 思路解析 我们用一个dp数组来收集我们从左往右&#xff0c;加起来的最大的和 也就是我们的节点不是负数&#xff0c;那我们直接收集就好了 如果是负数&#xff0c;我们就用Max&#xff08;&#xff09;比较是这个节点大还是当前节点大&#xff08;这个情…

如何利用天赋实现最大化的价值输出-补

原文&#xff1a; https://blog.csdn.net/ZhangRelay/article/details/145408621 ​​​​​​如何利用天赋实现最大化的价值输出-CSDN博客 如何利用天赋实现最大化的价值输出-CSDN博客 引用视频差异 第一段视频目标明确&#xff0c;建议也非常明确。 录制视频的人是主动性…

新能源算力战争:为什么AI大模型需要绿色数据中心?

新能源算力战争:为什么AI大模型需要绿色数据中心? 近年来,人工智能(AI)大模型的爆发式增长正在重塑全球科技产业的格局。以GPT-4、Gemini、Llama等为代表的千亿参数级模型,不仅需要海量数据训练,更依赖庞大的算力支撑。然而,这种算力的背后隐藏着一个日益严峻的挑战——…

Spring Boot 中的事件发布与监听:深入理解 ApplicationEventPublisher(附Demo)

目录 前言1. 基本知识2. Demo3. 实战代码 前言 &#x1f91f; 找工作&#xff0c;来万码优才&#xff1a;&#x1f449; #小程序://万码优才/r6rqmzDaXpYkJZF 基本的Java知识推荐阅读&#xff1a; java框架 零基础从入门到精通的学习路线 附开源项目面经等&#xff08;超全&am…

unity学习24:场景scene相关生成,加载,卸载,加载进度,异步加载场景等

目录 1 场景数量 SceneManager.sceneCount 2 直接代码生成新场景 SceneManager.CreateScene 3 场景的加载 3.1 用代码加载场景&#xff0c;仍然build setting里先加入配置 3.2 卸载场景 SceneManager.UnloadSceneAsync(); 3.3 同步加载场景 SceneManager.LoadScene 3.3.…

【Android】布局文件layout.xml文件使用控件属性android:layout_weight使布局较为美观,以RadioButton为例

目录 说明举例 说明 简单来说&#xff0c;android:layout_weight为当前控件按比例分配剩余空间。且单个控件该属性的具体数值不重要&#xff0c;而是多个控件的属性值之比发挥作用&#xff0c;例如有2个控件&#xff0c;各自的android:layout_weight的值设为0.5和0.5&#xff0…