📜  keras 回调 - Python (1)

📅  最后修改于: 2023-12-03 14:43:39.275000             🧑  作者: Mango

Keras回调

简介

Keras回调是一组函数,它们在训练模型的过程中被称为“回调”。回调可以在模型的每个时期开始或结束时执行操作,如模型检查点,模型停止,学习率调度等。它们是Keras中构建复杂模型的有用工具。

常用回调
ModelCheckpoint

ModelCheckpoint是Keras中最常用的回调之一。它允许你保存你训练的模型权重,以便在训练时出现崩溃时可以重新开始。可以在训练过程中每个时期保存一次或仅在验证准确性改进时保存。

from keras.callbacks import ModelCheckpoint

filepath = "weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5"

checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')

callbacks_list = [checkpoint]
EarlyStopping

EarlyStopping 是一个有用的回调,它可以在训练过程中监视验证准确性,并在多个周期中未改善时停止训练。这有助于防止过度拟合并提高模型的泛化能力。

from keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=3, verbose=1, mode='min')

callbacks_list = [early_stop]
ReduceLROnPlateau

ReduceLROnPlateau是一个用于动态调整学习率的回调。如果模型的性能在多个时期中没有改善,回调将减少学习率。这有助于让模型更好地训练,并避免在时期之间下降。

from keras.callbacks import ReduceLROnPlateau

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.0001)

callbacks_list = [reduce_lr]
自定义回调

Keras还允许开发人员创建自定义回调。这些回调可以执行任何操作,包括对输入数据进行在线增强,对训练过程进行可视化等。以下是一个简单的示例,该示例显示每个时期的训练和验证损失。

from keras.callbacks import Callback

class LossHistory(Callback):
    def __init__(self):
        self.train_losses = []
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        self.train_losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))

history = LossHistory()

model.fit(X_train, Y_train, validation_data=(X_test, Y_test), callbacks=[history])
总结

Keras回调是强大的工具,可以帮助您在模型训练过程中完成自动任务。回调可以执行各种操作,如检查点,早期停止和学习率调度,还可以让您自定义回调以执行任何其他自定义操作。了解这些回调的功能并使用它们有助于提高您的模型性能。