Keras training with shuffled tf.data: if training is interrupted, how to continue training at last data iteration/order of last saved checkpoint
我正在用 keras
有时我的云实例会在一个纪元完成之前断开连接或崩溃,但
我想在训练另一个 epoch 之前完成对那个 epoch 中数据的训练(我有很长的 epoch),因此每个数据示例在每个 epoch 训练一次。
有没有办法获取数据的原始顺序,以及模型最后保存在数据中的位置?
到目前为止我发现了什么
看起来您可以通过设置种子在 .shuffle 中设置特定顺序。但是,洗牌只发生在缓冲区中,所以我不能 100% 确定设置种子是否会完美地重现订单。另外,我不确定这将如何与
即使我确实获得了训练顺序的副本,我也不确定如何在顺序中找到模型最后保存的位置,然后从该点开始训练。我必须得到的一个想法是手动遍历数据集,直到我到达它。虽然我不确定
为了从上次保存模型的位置获取步骤/批次编号,我可能可以将其记录在某个地方。
这些解决方案似乎是粗略的解决方法,我想知道 Keras 中是否有一些我可能忽略的功能可以帮助解决这个问题。
似乎没有内置的 keras 可以做到这一点,但如果我错了,请纠正我。
我的方法
调试
为了调试并确保以相同的顺序生成 epoch 和 batch,我们需要一种方法来打印每个 epoch-batch 中数据点的获取方式。这很棘手,因此出于调试目的,我将使用回归问题并将基本事实作为序号。然后我可以有一个自定义损失,我可以在其中打印基本事实并使用户的顺序正确。
模型和数据
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 | import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import numpy as np import keras.backend as K # Data x_train = np.random.randn(15, 10).astype("float32") y_train = np.arange(15).astype("float32") # Custom MSE looss just to track the order in which data is picked up def my_mse(y_true, y_pred): tf.print(tf.keras.backend.flatten(y_true)) loss = K.square(y_pred - y_true) loss = K.sum(loss, axis=1) return loss # Model def get_model(): inputs = keras.Input(shape=(10)) outputs = layers.Dense(1, activation="linear")(inputs) model = keras.Model(inputs=inputs, outputs=outputs) model.compile( optimizer="rmsprop", loss=my_mse, ) return model |
数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8) epochs = 2 print ("Runs 1") for e in range(epochs): for i, (x, y) in enumerate(train_dataset): print (e, i, y) train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(8) print ("Runs 2") for e in range(epochs): for i, (x, y) in enumerate(train_dataset): print (e, i, y) |
输出:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | Runs 1 0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32) 1 tf.Tensor([ 6. 11. 10. 14. 9. 12. 13.], shape=(7,), dtype=float32) 2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32) 3 tf.Tensor([13. 10. 0. 14. 6. 11. 12.], shape=(7,), dtype=float32) 4 tf.Tensor([ 0. 1. 5. 6. 9. 3. 7. 14.], shape=(8,), dtype=float32) 5 tf.Tensor([13. 8. 4. 10. 2. 12. 11.], shape=(7,), dtype=float32) Runs 2 0 tf.Tensor([1. 3. 5. 7. 4. 0. 8. 2.], shape=(8,), dtype=float32) 1 tf.Tensor([ 6. 11. 10. 14. 9. 12. 13.], shape=(7,), dtype=float32) 2 tf.Tensor([4. 2. 5. 8. 1. 9. 7. 3.], shape=(8,), dtype=float32) 3 tf.Tensor([13. 10. 0. 14. 6. 11. 12.], shape=(7,), dtype=float32) 4 tf.Tensor([ 0. 1. 5. 6. 9. 3. 7. 14.], shape=(8,), dtype=float32) 5 tf.Tensor([13. 8. 4. 10. 2. 12. 11.], shape=(7,), dtype=float32) |
是的,使用种子复制订单。
现在让我们编写一个方法将数据集转发到某个时期和批次组合
1 2 3 4 5 6 7 8 9 10 | def forward(dataset, n=None): if not n: return dataset i = 0 while True: for _ in dataset: i += 1 if i == n: return dataset |
测试用例:
让我们正常运行并观察顺序
从一开始的数据
1 2 3 4 5 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None) model = get_model() model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False) |
输出:
1 2 3 4 5 6 7 8 9 10 11 12 | [7 3 6 10] [11 0 1 2] [8 14 9 13] [12 5 4] [5 8 6 3] [1 12 10 9] [2 11 0 4] [14 13 7] [2 3 0 10] [4 1 13 6] [8 7 14 11] [12 5 9] |
来自 Dataset 第 n 个状态的数据
让我们的数据集进行第四次迭代并运行训练
1 2 3 4 5 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4) model = get_model() model.fit(train_dataset, epochs=3, verbose=0, workers=4, shuffle=False) |
输出:
1 2 3 4 5 6 7 8 | [5 8 6 3] [1 12 10 9] [2 11 0 4] [14 13 7] [2 3 0 10] [4 1 13 6] [8 7 14 11] [12 5 9] |
很好,现在我们知道如何正确转发数据集了。现在让我们编写回调来跟踪当前的迭代次数:
跟踪迭代的自定义回调(epoch-batch 组合)
现在我们需要确定模型被检查指向的时期和批次组合。如果我们有这些信息,我们可以加载最后一个检查点模型并将我们的数据集转发到它的批次和时期组合并继续训练。我们将使用回调
来做到这一点
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | class MyCustomCallback(tf.keras.callbacks.ModelCheckpoint, keras.callbacks.Callback): def __init__(self, the_id=0, **args): self.the_id = the_id self.epoch = 0 super().__init__(**args) def _save_model(self, epoch, logs): logs['the_id'] = self.the_id super()._save_model(epoch, logs) def on_batch_end(self, batch, logs={}): self.the_id += 1 super().on_batch_end(batch, logs) checkpoint_filepath = 'checkpoint-{the_id}' model_checkpoint_callback = MyCustomCallback( filepath=checkpoint_filepath, save_freq=2, save_best_only=False) model = get_model() train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), None) model.fit(train_dataset, epochs=5, verbose=0, callbacks=[model_checkpoint_callback], workers=4, shuffle=False) |
输出:
1 2 3 4 5 6 7 8 9 10 11 12 | [7 3 6 10] [11 0 1 2] [8 14 9 13] [12 5 4] [5 8 6 3] [1 12 10 9] [2 11 0 4] [14 13 7] [2 3 0 10] [4 1 13 6] [8 7 14 11] [12 5 9] |
我们每两批检查一次。所以让我们假设它崩溃并且最后一个检查点是
1 2 3 4 5 | train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) train_dataset = forward(train_dataset.shuffle(buffer_size=8, reshuffle_each_iteration=True, seed=0).batch(4), 4) model = get_model() model.fit(train_dataset, epochs=2, verbose=0, workers=4, shuffle=False) |
输出:
1 2 3 4 5 6 7 8 | [5 8 6 3] [1 12 10 9] [2 11 0 4] [14 13 7] [2 3 0 10] [4 1 13 6] [8 7 14 11] [12 5 9] |
我想您想恢复随机播放顺序以避免在此时期内重复某些样本。
根据未完成时期的洗牌描述,您的模型只能访问数据集中的第一个 current_step_number shuffle_buffer_size 样本。
因此,当您恢复训练时,如果您知道处理了多少步,则可以跳过此步骤跳过 shuffle_buffer_size 步,然后您将继续在以下样本上进行训练,这在当前 epoch 内尚未观察到。
请注意,在此时期根本不会观察到来自数据集第一部分的一些随机 shuffle_buffer_size 样本。正如你所说,你的时代很长,所以,可能你有很多数据,所以丢失 shuffle_buffer_size 样本对你来说应该不是问题。
所以在保存检查点的过程中也要保存步数,然后在加载检查点后创建带有跳过步骤的数据集副本(使用 dataset.skip),然后将 model.fit 与这个较小的数据集一起使用一个时期(以完成当前时期),然后继续以平常的方式进行训练。