
Tensorflow: how to save/restore a model?

After you train a model in Tensorflow:

  • How do you save the trained model?
  • How do you later restore this saved model?

  • I am improving my answer to add more details on saving and restoring models.

    In (and after) Tensorflow version 0.11:

    Save the model:

    import tensorflow as tf

    #Prepare to feed input, i.e. feed_dict and placeholders
    w1 = tf.placeholder("float", name="w1")
    w2 = tf.placeholder("float", name="w2")
    b1= tf.Variable(2.0,name="bias")
    feed_dict ={w1:4,w2:8}

    #Define a test operation that we will restore
    w3 = tf.add(w1,w2)
    w4 = tf.multiply(w3,b1,name="op_to_restore")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    #Create a saver object which will save all the variables
    saver = tf.train.Saver()

    #Run the operation by feeding input
    print(sess.run(w4, feed_dict))
    #Prints 24, which is (w1 + w2) * b1

    #Now, save the graph
    saver.save(sess, 'my_test_model',global_step=1000)

    Restore the model:

    import tensorflow as tf

    sess=tf.Session()    
    #First let's load meta graph and restore weights
    saver = tf.train.import_meta_graph('my_test_model-1000.meta')
    saver.restore(sess,tf.train.latest_checkpoint('./'))


    # Access saved Variables directly
    print(sess.run('bias:0'))
    # This will print 2, which is the value of bias that we saved


    # Now, let's access and create placeholders variables and
    # create feed-dict to feed new data

    graph = tf.get_default_graph()
    w1 = graph.get_tensor_by_name("w1:0")
    w2 = graph.get_tensor_by_name("w2:0")
    feed_dict ={w1:13.0,w2:17.0}

    #Now, access the op that you want to run.
    op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

    print(sess.run(op_to_restore, feed_dict))
    #This will print 60, which is (13 + 17) * 2

    This and some more advanced use-cases are explained very well here:

    A quick and complete tutorial to save and restore Tensorflow models


    In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph, as described at https://www.tensorflow.org/programmers_guide/meta_graph.

    Save the model

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # The `save` method will call `export_meta_graph` implicitly.
    # You will get the saved graph file: my-model.meta

    Restore the model

    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)


    For TensorFlow version < 0.11.0RC1:

    The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.

    Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that restores variables saved in a prior run and computes predictions. Of course, you can also restore variables and continue training if you'd like.

    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)

    w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
    b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
    y_hat = tf.add(b, tf.matmul(x, w))

    ...more setup for optimization and what not...

    saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        if FLAGS.train:
            for i in xrange(FLAGS.training_steps):
                ...training loop...
                if (i + 1) % FLAGS.checkpoint_steps == 0:
                    saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                               global_step=i+1)
        else:
            # Here's where you're restoring the variables w and b.
            # Note that the graph is exactly as it was when the variables were
            # saved in a prior training run.
            ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                ...no checkpoint found...

            # Now you can run the model to get predictions
            batch_x = ...load some data...
            predictions = sess.run(y_hat, feed_dict={x: batch_x})

    Here are the docs for Variables, which cover saving and restoring, and here are the docs for the Saver.


    Docs

    They built a thorough and useful tutorial -> https://www.tensorflow.org/guide/saved_model

    From the docs:

    Save

    # Create some variables.
    v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
    v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

    inc_v1 = v1.assign(v1+1)
    dec_v2 = v2.assign(v2-1)

    # Add an op to initialize the variables.
    init_op = tf.global_variables_initializer()

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # Later, launch the model, initialize the variables, do some work, and save the
    # variables to disk.
    with tf.Session() as sess:
      sess.run(init_op)
      # Do some work with the model.
      inc_v1.op.run()
      dec_v2.op.run()
      # Save the variables to disk.
      save_path = saver.save(sess,"/tmp/model.ckpt")
      print("Model saved in path: %s" % save_path)

    Restore

    tf.reset_default_graph()

    # Create some variables.
    v1 = tf.get_variable("v1", shape=[3])
    v2 = tf.get_variable("v2", shape=[5])

    # Add ops to save and restore all the variables.
    saver = tf.train.Saver()

    # Later, launch the model, use the saver to restore variables from disk, and
    # do some work with the model.
    with tf.Session() as sess:
      # Restore variables from disk.
      saver.restore(sess,"/tmp/model.ckpt")
      print("Model restored.")
      # Check the values of the variables
      print("v1 : %s" % v1.eval())
      print("v2 : %s" % v2.eval())

    Tensorflow 2

    This is still in beta, so I'd advise against it for now. If you still want to go down that road, here is the tf.saved_model usage guide.
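
    For reference, here is a minimal TF 2.x sketch of that API, assuming a small Keras model; the toy model and the /tmp path are placeholders for illustration:

    import tensorflow as tf

    # A toy model just to have something to save; substitute your own.
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

    # Write the model in the SavedModel format...
    tf.saved_model.save(model, '/tmp/my_saved_model')

    # ...and load it back as a trackable object with its signatures.
    restored = tf.saved_model.load('/tmp/my_saved_model')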

    Tensorflow < 2: simple_save

    Many good answers; for completeness I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.

    Python 3; Tensorflow 1.14

    import tensorflow as tf
    from tensorflow.saved_model import tag_constants

    with tf.Graph().as_default():
        with tf.Session() as sess:
            ...

            # Saving
            inputs = {
               "batch_size_placeholder": batch_size_placeholder,
               "features_placeholder": features_placeholder,
               "labels_placeholder": labels_placeholder,
            }
            outputs = {"prediction": model_output}
            tf.saved_model.simple_save(
                sess, 'path/to/your/location/', inputs, outputs
            )

    Restoring:

    restored_graph = tf.Graph()
    with restored_graph.as_default():
        with tf.Session() as sess:
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                'path/to/your/location/',
            )
            batch_size_placeholder = restored_graph.get_tensor_by_name('batch_size_placeholder:0')
            features_placeholder = restored_graph.get_tensor_by_name('features_placeholder:0')
            labels_placeholder = restored_graph.get_tensor_by_name('labels_placeholder:0')
            prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

            sess.run(prediction, feed_dict={
                batch_size_placeholder: some_value,
                features_placeholder: some_other_value,
                labels_placeholder: another_value
            })

    Standalone example

    Original blog post

    The following code generates random data for the sake of the demonstration.

  • We start by creating the placeholders. They will hold the data at runtime. From them, we create the Dataset and then its Iterator. We get the iterator's generated tensor, called input_tensor, which will serve as input to our model.
  • The model itself is built from input_tensor: a GRU-based bidirectional RNN followed by a dense classifier. Because why not.
  • The loss is softmax_cross_entropy_with_logits, optimized with Adam. After 2 epochs (of 2 batches each), we save the "trained" model with tf.saved_model.simple_save. If you run the code as is, the model will be saved in a folder called simple/ in your current working directory.
  • In a new graph, we then restore the saved model with tf.saved_model.loader.load. We grab the placeholders and logits with graph.get_tensor_by_name and the Iterator initializing operation with graph.get_operation_by_name.
  • Lastly, we run an inference for both batches in the dataset, and check that the saved and restored models both yield the same values. They do!
  • Code:

    import os
    import shutil
    import numpy as np
    import tensorflow as tf
    from tensorflow.python.saved_model import tag_constants


    def model(graph, input_tensor):
        """Create the model which consists of
        a bidirectional rnn (GRU(10)) followed by a dense classifier

        Args:
            graph (tf.Graph): Tensors' graph
            input_tensor (tf.Tensor): Tensor fed as input to the model

        Returns:
            tf.Tensor: the model's output layer Tensor
        """

        cell = tf.nn.rnn_cell.GRUCell(10)
        with graph.as_default():
            ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=cell,
                cell_bw=cell,
                inputs=input_tensor,
                sequence_length=[10] * 32,
                dtype=tf.float32,
                swap_memory=True,
                scope=None)
            outputs = tf.concat((fw_outputs, bw_outputs), 2)
            mean = tf.reduce_mean(outputs, axis=1)
            dense = tf.layers.dense(mean, 5, activation=None)

            return dense


    def get_opt_op(graph, logits, labels_tensor):
        """Create optimization operation from model's logits and labels

        Args:
            graph (tf.Graph): Tensors' graph
            logits (tf.Tensor): The model's output without activation
            labels_tensor (tf.Tensor): Target labels

        Returns:
            tf.Operation: the operation performing a step of Adam optimizer
        """

        with graph.as_default():
            with tf.variable_scope('loss'):
                loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                        logits=logits, labels=labels_tensor, name='xent'),
                        name="mean-xent"
                        )
            with tf.variable_scope('optimizer'):
                opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
            return opt_op


    if __name__ == '__main__':
        # Set random seed for reproducibility
        # and create synthetic data
        np.random.seed(0)
        features = np.random.randn(64, 10, 30)
        labels = np.eye(5)[np.random.randint(0, 5, (64,))]

        graph1 = tf.Graph()
        with graph1.as_default():
            # Random seed for reproducibility
            tf.set_random_seed(0)
            # Placeholders
            batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
            features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
            labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
            # Dataset
            dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
            dataset = dataset.batch(batch_size_ph)
            iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
            dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
            input_tensor, labels_tensor = iterator.get_next()

            # Model
            logits = model(graph1, input_tensor)
            # Optimization
            opt_op = get_opt_op(graph1, logits, labels_tensor)

            with tf.Session(graph=graph1) as sess:
                # Initialize variables
                tf.global_variables_initializer().run(session=sess)
                for epoch in range(3):
                    batch = 0
                    # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                    sess.run(
                        dataset_init_op,
                        feed_dict={
                            features_data_ph: features,
                            labels_data_ph: labels,
                            batch_size_ph: 32
                        })
                    values = []
                    while True:
                        try:
                            if epoch < 2:
                                # Training
                                _, value = sess.run([opt_op, logits])
                                print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                                batch += 1
                            else:
                                # Final inference
                                values.append(sess.run(logits))
                                print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                                batch += 1
                        except tf.errors.OutOfRangeError:
                            break
                # Save model state
                print('\nSaving...')
                cwd = os.getcwd()
                path = os.path.join(cwd, 'simple')
                shutil.rmtree(path, ignore_errors=True)
                inputs_dict = {
                   "batch_size_ph": batch_size_ph,
                   "features_data_ph": features_data_ph,
                   "labels_data_ph": labels_data_ph
                }
                outputs_dict = {
                   "logits": logits
                }
                tf.saved_model.simple_save(
                    sess, path, inputs_dict, outputs_dict
                )
                print('Ok')
        # Restoring
        graph2 = tf.Graph()
        with graph2.as_default():
            with tf.Session(graph=graph2) as sess:
                # Restore saved values
                print('\nRestoring...')
                tf.saved_model.loader.load(
                    sess,
                    [tag_constants.SERVING],
                    path
                )
                print('Ok')
                # Get restored placeholders
                labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
                features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
                batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
                # Get restored model output
                restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
                # Get dataset initializing operation
                dataset_init_op = graph2.get_operation_by_name('dataset_init')

                # Initialize restored dataset
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    }

                )
                # Compute inference for both batches in dataset
                restored_values = []
                for i in range(2):
                    restored_values.append(sess.run(restored_logits))
                    print('Restored values: ', restored_values[i][0])

        # Check if original inference and restored inference are equal
        valid = all((v == rv).all() for v, rv in zip(values, restored_values))
        print('\nInferences match: ', valid)

    This will print:

    $ python3 save_and_restore.py

    Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
    Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
    Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
    Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
    Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
    Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

    Saving...
    INFO:tensorflow:Assets added to graph.
    INFO:tensorflow:No assets to write.
    INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
    Ok

    Restoring...
    INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
    Ok
    Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
    Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

    Inferences match:  True


    My environment: Python 3.6, Tensorflow 1.3.0

    Though there are many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by Saver, we have to either redefine the tensorflow network or use some weird, hard-to-remember names, e.g. 'placehold_0:0', 'dense/Adam/Weight:0'. Here I recommend using tf.saved_model instead; a minimal example is given below, and you can learn more from Serving a TensorFlow Model:

    Save the model:

    import tensorflow as tf

    # define the tensorflow network and do some trains
    x = tf.placeholder("float", name="x")
    w = tf.Variable(2.0, name="w")
    b = tf.Variable(0.0, name="bias")

    h = tf.multiply(x, w)
    y = tf.add(h, b, name="y")
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # save the model
    export_path =  './savedmodel'
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

    prediction_signature = (
      tf.saved_model.signature_def_utils.build_signature_def(
          inputs={'x_input': tensor_info_x},
          outputs={'y_output': tensor_info_y},
          method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

    builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
              prediction_signature
      },
      )
    builder.save()

    Load the model:

    import tensorflow as tf
    sess=tf.Session()
    signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    input_key = 'x_input'
    output_key = 'y_output'

    export_path =  './savedmodel'
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        export_path)
    signature = meta_graph_def.signature_def

    x_tensor_name = signature[signature_key].inputs[input_key].name
    y_tensor_name = signature[signature_key].outputs[output_key].name

    x = sess.graph.get_tensor_by_name(x_tensor_name)
    y = sess.graph.get_tensor_by_name(y_tensor_name)

    y_out = sess.run(y, {x: 3.0})


    The model has two parts: the model definition, saved by Supervisor as graph.pbtxt in the model directory, and the numerical values of tensors, saved into checkpoint files like model.ckpt-1003418.

    The model definition can be restored using tf.import_graph_def, and the weights are restored using Saver.

    However, Saver uses a special collection-holding list of the variables that is attached to the model Graph, and this collection is not initialized using import_graph_def, so you can't use the two together at the moment (it's on our roadmap to fix). For now, you have to use Ryan Sepassi's approach: manually construct a graph with identical node names, and use Saver to load the weights into it.

    (Alternatively you could hack it by using import_graph_def, creating the variables manually, and using tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using Saver.)
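
    For illustration, here is a rough, hedged sketch of that hack. The graph file, checkpoint path, and variable names/shapes ('w', 'b') are hypothetical and would have to match your own model:

    import tensorflow as tf
    from google.protobuf import text_format

    # Import the graph definition (this restores the structure, NOT the weights).
    with open('graph.pbtxt') as f:
        graph_def = text_format.Parse(f.read(), tf.GraphDef())
    tf.import_graph_def(graph_def, name='imported')

    # Manually recreate each variable under the names used in the checkpoint
    # and register it in the variables collection, mirroring the step above
    # (tf.Variable also registers itself automatically).
    w = tf.Variable(tf.zeros([1, 1]), name='w')
    b = tf.Variable(tf.zeros([1]), name='b')
    for var in (w, b):
        tf.add_to_collection(tf.GraphKeys.VARIABLES, var)

    # Saver matches variables to checkpoint entries by name.
    saver = tf.train.Saver([w, b])
    with tf.Session() as sess:
        saver.restore(sess, 'model.ckpt-1003418')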


    You can also take this simpler way.

    Step 1: Initialize all your variables

    W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
    B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

    # Similarly, create W2, B2, W3, ...

    Step 2: Save the session inside a model Saver and save it

    model_saver = tf.train.Saver()

    # Train the model and save it in the end
    model_saver.save(session,"saved_models/CNN_New.ckpt")

    Step 3: Restore the model

    with tf.Session(graph=graph_cnn) as session:
        model_saver.restore(session,"saved_models/CNN_New.ckpt")
        print("Model restored.")
        print('Initialized')

    Step 4: Check your variables

    W1 = session.run(W1)
    print(W1)

    While running in a different python instance, use:

    with tf.Session() as sess:
        # Initialize the variables
        sess.run(tf.global_variables_initializer())

        # Restore the latest checkpoint (restored values overwrite the fresh ones)
        saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

        # Get the default graph (supply your custom graph if you have one)
        graph = tf.get_default_graph()

        # It will give the tensor object
        W1 = graph.get_tensor_by_name('W1:0')

        # To get the value (numpy array)
        W1_value = sess.run(W1)


    In most cases, saving and restoring from disk using tf.train.Saver is your best option:

    ... # build your model
    saver = tf.train.Saver()

    with tf.Session() as sess:
        ... # train the model
        saver.save(sess,"/tmp/my_great_model")

    with tf.Session() as sess:
        saver.restore(sess,"/tmp/my_great_model")
        ... # use the model

    You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can then use to restore the model's state:

    saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

    with tf.Session() as sess:
        saver.restore(sess,"/tmp/my_great_model")
        ... # use the model

    However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves during training (as measured on the validation set), and then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:

    ... # build your model

    # get a handle on the graph nodes we need to save/restore the model
    graph = tf.get_default_graph()
    gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    assign_ops = [graph.get_operation_by_name(v.op.name +"/Assign") for v in gvars]
    init_values = [assign_op.inputs[1] for assign_op in assign_ops]

    with tf.Session() as sess:
        ... # train the model

        # when needed, save the model state to memory
        gvars_state = sess.run(gvars)

        # when needed, restore the model state
        feed_dict = {init_value: val
                     for init_value, val in zip(init_values, gvars_state)}
        sess.run(assign_ops, feed_dict=feed_dict)

    Quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value to any op, not just placeholders, so this works fine.
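
    As a tiny self-contained illustration of that last point (the names here are arbitrary), you can feed a value into a tensor that is not a placeholder:

    import tensorflow as tf

    x = tf.Variable(42.0, name='x')
    doubled = x * 2.0

    with tf.Session() as sess:
        sess.run(x.initializer)
        print(sess.run(doubled))                       # 84.0, uses the variable's value
        print(sess.run(doubled, feed_dict={x: 10.0}))  # 20.0, the fed value wins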


    As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.

    I implemented this for my personal use, so I thought I'd share the code here.

    Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

    (This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)


    If it is an internally saved model, you just specify a restorer for all variables as

    restorer = tf.train.Saver(tf.all_variables())

    and use it to restore variables in the current session:

    restorer.restore(self._sess, model_file)

    For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command

    python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

    The inspect_checkpoint.py script can be found in the './tensorflow/python/tools' folder of the Tensorflow source.
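
    For a concrete picture, here is a hedged sketch of such a mapping; the old/new names and shapes below are made up, and the keys must be the names reported by inspect_checkpoint.py:

    import tensorflow as tf

    # Variables in your current graph (hypothetical names and shapes).
    my_w = tf.Variable(tf.zeros([3, 3]), name='new_scope/weights')
    my_b = tf.Variable(tf.zeros([3]), name='new_scope/biases')

    # Map checkpoint variable names -> variables in the current graph.
    restorer = tf.train.Saver({
        'old_scope/weights': my_w,
        'old_scope/biases': my_b,
    })

    with tf.Session() as sess:
        restorer.restore(sess, '/path/to/pretrained_model/model.ckpt')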

    To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here.


    Here is my simple solution for the two basic cases, differing on whether you want to load the graph from a file or build it during runtime.

    This answer holds for Tensorflow 0.12+ (including 1.0).

    Rebuilding the graph in code

    Saving

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')

    Loading

    graph = ... # build the graph
    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        # now you can use the graph, continue training or whatever

    Loading the graph from a file

    When using this technique, make sure all your layers/variables have explicitly set unique names; otherwise Tensorflow will make the names unique itself, and they'll then differ from the names stored in the file. This is not a problem with the previous technique, because the names are "mangled" the same way when both loading and saving.
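
    For example (a minimal sketch, with arbitrary names), set the names explicitly while building the graph:

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 10], name='x')       # explicit name
    w1 = tf.Variable(tf.truncated_normal([10, 5]), name='w1')  # explicit name
    logits = tf.layers.dense(x, 5, name='classifier')          # names the layer's scope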

    Saving

    graph = ... # build the graph

    for op in [ ... ]:  # operators you want to use after restoring the model
        tf.add_to_collection('ops_to_restore', op)

    saver = tf.train.Saver()  # create the saver after the graph
    with ... as sess:  # your session object
        saver.save(sess, 'my-model')

    Loading

    with ... as sess:  # your session object
        saver = tf.train.import_meta_graph('my-model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('./'))
        ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection


    You can also check out examples in TensorFlow/skflow, which offers save and restore methods that can help you easily manage your models. It has parameters that also let you control how frequently your model is backed up.


    If you use tf.train.MonitoredTrainingSession as the default session, you don't need to add extra code to save/restore things. Just pass a checkpoint directory name to MonitoredTrainingSession's constructor; it will use session hooks to handle these.
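
    A minimal sketch of what that looks like; the checkpoint directory and the toy training op are placeholders for illustration:

    import tensorflow as tf

    # Toy graph: minimize (x - 3)^2, with a global step for checkpoint numbering.
    global_step = tf.train.get_or_create_global_step()
    x = tf.Variable(0.0, name='x')
    loss = tf.square(x - 3.0)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)

    # Passing checkpoint_dir is enough: the session saves periodically via a
    # hook and automatically restores from the latest checkpoint on restart.
    with tf.train.MonitoredTrainingSession(checkpoint_dir='/tmp/train_logs',
                                           save_checkpoint_secs=60) as sess:
        for _ in range(100):
            sess.run(train_op)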


    All of the answers here are great, but I want to add two things.

    First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.

    For example, you can save a graph without "./" in the file name like so:

    # Some graph defined up here with specific names

    saver = tf.train.Saver()
    save_file = 'model.ckpt'

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, save_file)

    But in order to restore the graph, you may find that you need to prepend "./" to the file name:

    # Same graph defined up here

    saver = tf.train.Saver()
    save_file = './' + 'model.ckpt' # String addition used for emphasis

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, save_file)

    You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.

    I also want to mention that sess.run(tf.global_variables_initializer()) before restoring the session can be important.

    If you are receiving an error regarding uninitialized variables when trying to restore a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restore(sess, save_file) line. It can save you a headache.


    As described in issue 6255:

    use './model_name.ckpt'
    saver.restore(sess, './my_model_final.ckpt')

    instead of

    saver.restore('my_model_final.ckpt')

    In newer Tensorflow versions, tf.train.Checkpoint is the preferable way to save and restore a model:

    Checkpoint.save and Checkpoint.restore write and read object-based
    checkpoints, in contrast to tf.train.Saver which writes and reads
    variable.name based checkpoints. Object-based checkpointing saves a
    graph of dependencies between Python objects (Layers, Optimizers,
    Variables, etc.) with named edges, and this graph is used to match
    variables when restoring a checkpoint. It can be more robust to
    changes in the Python program, and helps to support restore-on-create
    for variables when executing eagerly. Prefer tf.train.Checkpoint over
    tf.train.Saver for new code.

    Here is an example:

    import tensorflow as tf
    import os

    tf.enable_eager_execution()

    checkpoint_directory ="/tmp/training_checkpoints"
    checkpoint_prefix = os.path.join(checkpoint_directory,"ckpt")

    checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
    status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
    for _ in range(num_training_steps):
      optimizer.minimize( ... )  # Variables will be restored on creation.
    status.assert_consumed()  # Optional sanity checks.
    checkpoint.save(file_prefix=checkpoint_prefix)

    More information and examples can be found here.


    For tensorflow 2.0, it is as simple as

    # Save the model
    model.save('path_to_my_model.h5')

    To restore:

    new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')


    If you want to reduce the model size, use tf.train.Saver to save the model and remember to specify the var_list. The var_list can be tf.trainable_variables or tf.global_variables.
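
    For instance, a short sketch (the variable names are made up). Saving only tf.trainable_variables leaves non-trainable state, such as optimizer slot variables, out of the checkpoint, which is what shrinks it:

    import tensorflow as tf

    w = tf.Variable(tf.zeros([10]), name='w')                    # trainable, saved
    step = tf.Variable(0, trainable=False, name='global_step')   # skipped below

    saver = tf.train.Saver(var_list=tf.trainable_variables())

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, './small_model.ckpt')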


    You can save the variables in the network using:

    saver = tf.train.Saver()
    saver.save(sess, 'path of save/fileName.ckpt')

    To restore the network for reuse later or in another script, use:

    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint('path of save/'))
    sess.run(....)

    Important points:

  • The sess must be the same between the first and later runs (coherent structure).
  • saver.restore needs the path of the folder containing the saved files, not an individual file path.

  • Wherever you want to save the model,

    self.saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ...
        self.saver.save(sess, filename)

    Make sure all your tf.Variables have names, because you may want to restore them later using their names.
    And wherever you want to predict,

    saver = tf.train.import_meta_graph(filename)
    name = 'name given when you saved the file'
    with tf.Session() as sess:
          saver.restore(sess, name)
          print(sess.run('W1:0')) #example to retrieve by variable name

    Make sure that the saver runs inside the corresponding session.
    Remember that if you use tf.train.latest_checkpoint('./'), then only the latest checkpoint will be used.


    I'm on version:

    tensorflow (1.13.1)
    tensorflow-gpu (1.13.1)

    The simple way is

    Save:

    model.save("model.h5")

    Restore:

    model = tf.keras.models.load_model("model.h5")