model.compile
是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。
注意事项: 在开始训练(调用 model.fit
)之前,必须先调用 model.compile()
。
1 基本用法
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
1) optimizer
: 优化器
可以是预定义优化器的字符串(如 'adam'
, 'sgd'
等),也可以是 tf.keras.optimizers
下的优化器实例。优化器负责调整模型的权重以最小化损失函数。
以下是可以使用的字符串参数:
'sgd'
: 随机梯度下降优化器'adam'
: Adam 优化器'rmsprop'
: RMSprop 优化器'adagrad'
: Adagrad 优化器'adadelta'
: Adadelta 优化器'adamax'
: Adamax 优化器'nadam'
: Nadam 优化器'ftrl'
: Ftrl 优化器
需要注意的是:
-
这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。
-
使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
-
‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。
-
在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。
2) loss
: 损失函数
用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses
下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。
以下是一些常用的字符串参数对应的损失函数:
'binary_crossentropy'
: 用于二分类问题的交叉熵损失。'categorical_crossentropy'
: 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。'sparse_categorical_crossentropy'
: 用于多分类问题的交叉熵损失,标签为整数。'mean_squared_error'
或'mse'
: 均方误差损失,用于回归问题。'mean_absolute_error'
或'mae'
: 平均绝对误差损失,用于回归问题。'mean_absolute_percentage_error'
或'mape'
: 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'
或'msle'
: 均方对数误差,用于回归问题,对小差异不敏感。'poisson'
: 泊松损失,适用于计数问题或其他泊松分布问题。'kullback_leibler_divergence'
或'kld'
: Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。'hinge'
: 用于“最大间隔”分类问题的铰链损失。'squared_hinge'
: 铰链损失的平方版本。'logcosh'
: 对数双曲余弦损失,用于回归问题,对异常值不敏感。
3) metrics
: 评估指标列表,用于评估模型的性能
这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy'
、'precision'
、'recall'
等。
在 model.compile()
方法中,metrics
参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:
'accuracy'
或'acc'
: 准确率,用于分类问题。'binary_accuracy'
: 二分类准确率。'categorical_accuracy'
: 多分类准确率,要求标签为 one-hot 编码。'sparse_categorical_accuracy'
: 多分类准确率,标签为整数。'top_k_categorical_accuracy'
: Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。'sparse_top_k_categorical_accuracy'
: 与'top_k_categorical_accuracy'
类似,但适用于标签为整数的情况。'mean_squared_error'
或'mse'
: 均方误差,用于回归问题。'mean_absolute_error'
或'mae'
: 平均绝对误差,用于回归问题。'mean_absolute_percentage_error'
或'mape'
: 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'
或'msle'
: 均方对数误差,用于回归问题。'cosine_similarity'
: 余弦相似度,用于回归问题或多标签分类问题。'precision'
: 精确率,用于二分类或多标签分类问题。'recall'
: 召回率,用于二分类或多标签分类问题。'auc'
: 曲线下面积(Area Under the Curve),用于二分类问题。
使用示例:
# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])
对于一些特定的指标(如 'precision'
, 'recall'
, 'auc'
等),可能需要使用 tf.keras.metrics
下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。
from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])
2 高级用法
- 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
- 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
- 使用多个损失函数和评估指标:
如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。
model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
- 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])