Tensorflow: how to save/restore a model?
After you train a model in Tensorflow: how do you save the trained model, and how do you later restore it?
I'm improving my answer to add more details on saving and restoring models.

In (and after) Tensorflow version 0.11:

Save the model:
import tensorflow as tf

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))
# Prints 24, which is (w1 + w2) * b1

# Now, save the graph
saver.save(sess, 'my_test_model', global_step=1000)
Restore the model:
import tensorflow as tf

sess = tf.Session()

# First, let's load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved

# Now, let's access the placeholders and
# create a feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))
# This will print 60, which is (13 + 17) * 2
This, and some more advanced use-cases, are explained very well here:

A quick complete tutorial to save and restore Tensorflow models
In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph according to https://www.tensorflow.org/programmers_guide/meta_graph.
Save the model
w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files: my-model.meta
Restore the model
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
For TensorFlow version < 0.11.0RC1:
The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.
Here's an example for linear regression, where there's a training loop that saves variable checkpoints and an evaluation section that restores variables saved in a prior run and computes predictions. Of course, you can also restore variables and continue training if you'd like.
x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))
...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})
Docs
They built a thorough and useful tutorial -> https://www.tensorflow.org/guide/saved_model
From the docs:
Save
# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

inc_v1 = v1.assign(v1 + 1)
dec_v2 = v2.assign(v2 - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    inc_v1.op.run()
    dec_v2.op.run()
    # Save the variables to disk.
    save_path = saver.save(sess, "/tmp/model.ckpt")
    print("Model saved in path: %s" % save_path)
Restore
tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model restored.")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())
Tensorflow 2
This is still beta, so I'd advise against it for now. If you still want to go down that road, here is the tf.saved_model usage guide.
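For reference, a minimal hedged sketch of what the TF2 route looks like with tf.saved_model — the module, path, and signature below are made-up placeholders, not part of the original answer:

import tensorflow as tf

# A tiny tf.Module standing in for a real model
class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
    def add(self, x):
        return x + 1.0

module = Adder()
tf.saved_model.save(module, '/tmp/adder')     # writes a SavedModel directory
restored = tf.saved_model.load('/tmp/adder')  # restores the traced signatures
print(restored.add(tf.constant(2.0)))         # tf.Tensor(3.0, shape=(), dtype=float32)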
Tensorflow < 2
Lots of good answers; for completeness, I'll add my 2 cents: simple_save. Also a standalone code example using the tf.data.Dataset API.
Python 3; Tensorflow 1.14
import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )
Restoring:
graph = tf.Graph()
with graph.as_default():
    with tf.Session(graph=graph) as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })
Standalone example

Original blog post

The following code generates random data for the sake of the demonstration.

Code:
import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)
        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                name="mean-xent")
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break

            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')

    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }
            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)
This will print:
$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595 0.12804556 0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045 -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792 -0.00602257 0.07465433 0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984 0.05981954 -0.15913513 -0.3244143 0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values: [-0.26331693 -0.13013336 -0.12553 -0.04276478 0.2933622 ]
Restored values: [-0.07730117 0.11119192 -0.20817074 -0.35660955 0.16990358]

Inferences match: True
My environment: Python 3.6, Tensorflow 1.3.0
Though there are many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by a Saver, we have to either redefine the TensorFlow network or use awkward, hard-to-remember tensor names. Here I recommend using tf.saved_model instead; a simple example is given below.
Save the model:
import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path = './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'x_input': tensor_info_x},
        outputs={'y_output': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            prediction_signature
    },
)
builder.save()
Load the model:
import tensorflow as tf

sess = tf.Session()
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path = './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
    sess,
    [tf.saved_model.tag_constants.SERVING],
    export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})
The model has two parts: the model definition, saved as graph.pbtxt in the model directory, and the numerical values of the tensors, saved into checkpoint files like model.ckpt-1003418.

The model definition can be restored using tf.import_graph_def, and the weights are restored using a Saver.

However, the Saver uses a special collection holding the list of variables attached to the model's Graph, and this collection is not populated by import_graph_def, so you can't use the two together at the moment. For now, you have to manually construct a graph with identical node names and use a Saver to load the weights into it.

(Alternatively, you could hack it by using import_graph_def, creating the variables manually, and calling tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using a Saver.)
You can also take this simpler approach.
Step 1: Initialize all your variables
W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")
# Similarly, W2, B2, W3, ...
Step 2: Save the session in a model Saver and save it
model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")
Step 3: Restore the model
with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.")
    print('Initialized')
Step 4: Check your variables
W1 = session.run(W1)
print(W1)
While running in a different python instance, use:
with tf.Session() as sess:
    # Initialize the variables first
    sess.run(tf.global_variables_initializer())

    # Restore latest checkpoint (the restored values overwrite the initialized ones)
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = sess.run(W1)
In most cases, saving and restoring from disk using a tf.train.Saver is your best option:
... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model
You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it; it restores the graph structure and returns a Saver that you can use to restore the model's state:
saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model
However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save a checkpoint every time the model improves during training (as measured on the validation set), and then if there is no progress for some time, roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then restore them later:
... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)
Quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops, we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just placeholders, so this works fine.
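As a tiny self-contained illustration of that last point — TensorFlow will happily accept a feed for a tensor that isn't a placeholder (the values here are arbitrary):

import tensorflow as tf

a = tf.constant(2.0)
b = a * 3.0

with tf.Session() as sess:
    print(sess.run(b))                       # 6.0, computed from the constant
    print(sess.run(b, feed_dict={a: 10.0}))  # 30.0, the fed value overrides the constant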
As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.

I implemented this for my personal use, so I thought I'd share the code here.

Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)
If it is an internally saved model, you just specify a restorer for all variables as
restorer = tf.train.Saver(tf.all_variables())
and use it to restore variables in the current session:
restorer.restore(self._sess, model_file)
For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command
python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt
The inspect_checkpoint.py script can be found in the "./tensorflow/python/tools" folder of the Tensorflow source.
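Once you know the names, you can build a Saver from a dictionary that maps checkpoint names to your variables. A hedged sketch — 'pretrained/weights', 'pretrained/bias', and the shapes below are hypothetical, not from the original answer:

import tensorflow as tf

# Your graph's variables (shapes are illustrative)
my_w = tf.Variable(tf.zeros([10, 5]), name='my_scope/weights')
my_b = tf.Variable(tf.zeros([5]), name='my_scope/bias')

# Keys are the names as they appear in the checkpoint, values are your variables
restorer = tf.train.Saver({'pretrained/weights': my_w,
                           'pretrained/bias': my_b})

with tf.Session() as sess:
    restorer.restore(sess, '/path/to/pretrained_model/model.ckpt')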
To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here.
Here is my simple solution for the two basic cases, which differ on whether you want to load the graph from a file or build it at runtime.

This answer holds for Tensorflow 0.12+ (including 1.0).

Rebuilding the graph in code
Saving
graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')
Loading
graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever
Loading the graph from a file

When using this technique, make sure all your layers/variables have explicitly set unique names; see the sketch below. Otherwise, Tensorflow will make the names unique itself and they'll thus differ from the names stored in the file. It isn't a problem with the previous technique, because the names are "mangled" the same way in both loading and saving.
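For instance, a quick sketch of explicit naming (the names and shapes themselves are just examples):

x = tf.placeholder(tf.float32, [None, 64], name='input_x')
w = tf.Variable(tf.zeros([64, 10]), name='fc1_weights')   # explicit, unique name
b = tf.Variable(tf.zeros([10]), name='fc1_bias')
logits = tf.identity(tf.matmul(x, w) + b, name='logits')  # name the op you want to fetch after restoring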
Saving
graph = ... # build the graph
for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')
Loading
with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection
You can also check out examples in TensorFlow/skflow, which offers save and restore methods that can help you easily manage your models.
If you use tf.train.MonitoredTrainingSession as the default session, you don't need to add extra code to do saving/restoring. Just pass a checkpoint directory name to MonitoredTrainingSession's constructor; it will use session hooks to handle these.
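A minimal sketch of that, with a toy loss standing in for a real model and a placeholder directory name (both are assumptions, not from the original answer):

import tensorflow as tf

x = tf.Variable(1.0, name='x')
global_step = tf.train.get_or_create_global_step()
loss = tf.square(x - 3.0)  # stand-in for a real training loss
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss, global_step=global_step)

# Checkpoints are written to (and restored from) checkpoint_dir automatically
with tf.train.MonitoredTrainingSession(checkpoint_dir='/tmp/train_logs') as sess:
    for _ in range(100):
        sess.run(train_op)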
All the answers here are great, but I want to add two things.

First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.

For example, you can save a graph without "./" in the file name, like so:
# Some graph defined up here with specific names
saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)
But in order to restore the graph, you may need to prepend a "./" to the file name:
# Same graph defined up here
saver = tf.train.Saver()
save_file = './' + 'model.ckpt'  # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)
You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.

I also want to mention that sess.run(tf.global_variables_initializer()) can be important before restoring the session.

If you are receiving an error regarding uninitialized variables when trying to restore a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restore(...) line. It can save you a headache.
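In other words, a sketch of the safe ordering — initialize first, then restore, so the restored values overwrite the freshly initialized ones:

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())   # initialize everything first
    saver.restore(sess, './my_model_final.ckpt')  # then restore overwrites with the saved values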
As described in issue 6255:
use './model_name.ckpt'
saver.restore(sess, './my_model_final.ckpt')
instead of
saver.restore('my_model_final.ckpt')
According to the new Tensorflow version, tf.train.Checkpoint is the preferable way of saving and restoring a model:
Checkpoint.save and Checkpoint.restore write and read object-based
checkpoints, in contrast to tf.train.Saver which writes and reads
variable.name based checkpoints. Object-based checkpointing saves a
graph of dependencies between Python objects (Layers, Optimizers,
Variables, etc.) with named edges, and this graph is used to match
variables when restoring a checkpoint. It can be more robust to
changes in the Python program, and helps to support restore-on-create
for variables when executing eagerly. Prefer tf.train.Checkpoint over
tf.train.Saver for new code.
Here is an example:
import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
    optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)
More information and examples here.
For tensorflow 2.0, it is as simple as

# Save the model
model.save('path_to_my_model.h5')
To restore:
new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
If you want to reduce the model size, use tf.train.Saver to save the model, and remember that you need to specify the var_list. The var_list can be tf.trainable_variables or tf.global_variables.
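A minimal sketch of that, assuming the graph and a session sess already exist (the file names are placeholders):

# Save only the trainable variables to keep the checkpoint small
saver = tf.train.Saver(var_list=tf.trainable_variables())
saver.save(sess, 'small_model.ckpt')

# Or save everything, including non-trainable bookkeeping variables
saver_all = tf.train.Saver(var_list=tf.global_variables())
saver_all.save(sess, 'full_model.ckpt')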
You can save the variables in the network using
saver = tf.train.Saver()
saver.save(sess, 'path of save/fileName.ckpt')
To restore the network for reuse later or in another script, use:
saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/'))
sess.run(....)
Key points:
Wherever you want to save the model,
self.saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ...
    self.saver.save(sess, filename)
Make sure all your tf.Variables have names, because you may want to restore them later using their names.
Wherever you want to predict,
saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file'

with tf.Session() as sess:
    saver.restore(sess, name)
    print(sess.run('W1:0'))  # example to retrieve by variable name
确保在相应的会话中运行保护程序。
Remember that if you use tf.train.latest_checkpoint('./'), then only the latest checkpoint will be used.
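tf.train.latest_checkpoint simply returns the path to the newest checkpoint in a directory (or None if there is none), so you can guard the restore — a small sketch, assuming saver and sess already exist:

ckpt_path = tf.train.latest_checkpoint('./')  # e.g. './my_test_model-1000', or None
if ckpt_path is not None:
    saver.restore(sess, ckpt_path)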
I'm on version:
tensorflow (1.13.1)
tensorflow-gpu (1.13.1)
The simple way is
To save:
model.save("model.h5")
To restore:
model = tf.keras.models.load_model("model.h5")