成都网站建设推广详,二维码生成短链接,就业合同模板,揭阳网站制作方案定制将训练好的模型参数保存起来#xff0c;以便以后进行验证或测试#xff0c;这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存#xff0c;先要创建一个Saver对象#xff1a;如
savertf.train.Saver()
在创建这个Saver对象的时候#xff…将训练好的模型参数保存起来以便以后进行验证或测试这是我们经常要做的事情。tf里面提供模型保存的是tf.train.Saver()模块。
模型保存先要创建一个Saver对象如
savertf.train.Saver()
在创建这个Saver对象的时候有一个参数我们经常会用到就是 max_to_keep 参数这个是用来设置保存模型的个数默认为5即 max_to_keep5保存最近的5个模型。如果你想每训练一代epoch)就想保存一次模型则可以将 max_to_keep设置为None或者0如
savertf.train.Saver(max_to_keep0)
但是这样做除了多占用硬盘并没有实际多大的用处因此不推荐。
当然如果你只想保存最后一代的模型则只需要将max_to_keep设置为1即可即
savertf.train.Saver(max_to_keep1)
创建完saver对象后就可以保存训练好的模型了如
saver.save(sess,ckpt/mnist.ckpt,global_stepstep)
第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字第三个参数将训练的次数作为后缀加入到模型名字中。
saver.save(sess,my-model, global_step0) filename: my-model-0 ... saver.save(sess, my-model, global_step1000) filename: my-model-1000
看一个mnist实例 # -*- coding:utf-8 -*- Created on SunJun 4 10:29:48 2017 author:Administrator import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist input_data.read_data_sets(MNIST_data/, one_hotFalse) x tf.placeholder(tf.float32, [None, 784])
y_tf.placeholder(tf.int32,[None,]) dense1 tf.layers.dense(inputsx, units1024, activationtf.nn.relu, kernel_initializertf.truncated_normal_initializer(stddev0.01), kernel_regularizertf.nn.l2_loss)
dense2tf.layers.dense(inputsdense1, units512, activationtf.nn.relu, kernel_initializertf.truncated_normal_initializer(stddev0.01), kernel_regularizertf.nn.l2_loss)
logitstf.layers.dense(inputsdense2, units10, activationNone, kernel_initializertf.truncated_normal_initializer(stddev0.01), kernel_regularizertf.nn.l2_loss) losstf.losses.sparse_softmax_cross_entropy(labelsy_,logitslogits)
train_optf.train.AdamOptimizer(learning_rate0.001).minimize(loss)
correct_prediction tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)
acctf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sesstf.InteractiveSession()
sess.run(tf.global_variables_initializer()) savertf.train.Saver(max_to_keep1)
for i in range(100): batch_xs, batch_ys mnist.train.next_batch(100) sess.run(train_op, feed_dict{x: batch_xs,y_: batch_ys}) val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels}) print(epoch:%d, val_loss:%f, val_acc:%f%(i,val_loss,val_acc)) saver.save(sess,ckpt/mnist.ckpt,global_stepi1)
sess.close() 代码中红色部分就是保存模型的代码虽然我在每训练完一代的时候都进行了保存但后一次保存的模型会覆盖前一次的最终只会保存最后一次。因此我们可以节省时间将保存代码放到循环之外仅适用max_to_keep1,否则还是需要放在循环内).
在实验中最后一代可能并不是验证精度最高的一代因此我们并不想默认保存最后一代而是想保存验证精度最高的一代则加个中间变量和判断语句就可以了。 savertf.train.Saver(max_to_keep1)
max_acc0
for i in range(100): batch_xs, batch_ys mnist.train.next_batch(100) sess.run(train_op, feed_dict{x: batch_xs,y_: batch_ys}) val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels}) print(epoch:%d, val_loss:%f, val_acc:%f%(i,val_loss,val_acc)) ifval_accmax_acc: max_accval_acc saver.save(sess,ckpt/mnist.ckpt,global_stepi1)
sess.close() 如果我们想保存验证精度最高的三代且把每次的验证精度也随之保存下来则我们可以生成一个txt文件用于保存。 savertf.train.Saver(max_to_keep3)
max_acc0
fopen(ckpt/acc.txt,w)
for i in range(100): batch_xs, batch_ys mnist.train.next_batch(100) sess.run(train_op, feed_dict{x: batch_xs,y_: batch_ys}) val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels}) print(epoch:%d, val_loss:%f, val_acc:%f%(i,val_loss,val_acc)) f.write(str(i1), val_acc:str(val_acc)\n) if val_accmax_acc: max_accval_acc saver.save(sess,ckpt/mnist.ckpt,global_stepi1)
f.close()
sess.close() 模型的恢复用的是restore()函数它需要两个参数restore(sess,save_path)save_path指的是保存的模型路径。我们可以使用tf.train.latest_checkpoint来自动获取最后一次保存的模型。如
model_filetf.train.latest_checkpoint(ckpt/)
saver.restore(sess,model_file)
则程序后半段代码我们可以改为 sesstf.InteractiveSession()
sess.run(tf.global_variables_initializer()) is_trainFalse
savertf.train.Saver(max_to_keep3) #训练阶段
if is_train: max_acc0 fopen(ckpt/acc.txt,w) for i in range(100): batch_xs, batch_ys mnist.train.next_batch(100) sess.run(train_op, feed_dict{x:batch_xs, y_: batch_ys}) val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels}) print(epoch:%d, val_loss:%f, val_acc:%f%(i,val_loss,val_acc)) f.write(str(i1), val_acc:str(val_acc)\n) if val_accmax_acc: max_accval_acc saver.save(sess,ckpt/mnist.ckpt,global_stepi1) f.close() #验证阶段
else: model_filetf.train.latest_checkpoint(ckpt/) saver.restore(sess,model_file) val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels}) print(val_loss:%f, val_acc:%f%(val_loss,val_acc))
sess.close() 标红的地方就是与保存、恢复模型相关的代码。用一个bool型变量is_train来控制训练和验证两个阶段。
整个源程序 # -*- coding:utf-8 -*-Created on SunJun 4 10:29:48 2017author:Administratorimport tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist input_data.read_data_sets(MNIST_data/, one_hotFalse)x tf.placeholder(tf.float32, [None, 784])y_tf.placeholder(tf.int32,[None,])dense1 tf.layers.dense(inputsx,units1024,activationtf.nn.relu,kernel_initializertf.truncated_normal_initializer(stddev0.01),kernel_regularizertf.nn.l2_loss)dense2tf.layers.dense(inputsdense1,units512,activationtf.nn.relu,kernel_initializertf.truncated_normal_initializer(stddev0.01),kernel_regularizertf.nn.l2_loss)logitstf.layers.dense(inputsdense2,units10,activationNone,kernel_initializertf.truncated_normal_initializer(stddev0.01),kernel_regularizertf.nn.l2_loss)losstf.losses.sparse_softmax_cross_entropy(labelsy_,logitslogits)train_optf.train.AdamOptimizer(learning_rate0.001).minimize(loss)correct_prediction tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_) acctf.reduce_mean(tf.cast(correct_prediction, tf.float32))sesstf.InteractiveSession() sess.run(tf.global_variables_initializer())is_trainTruesavertf.train.Saver(max_to_keep3)#训练阶段if is_train:max_acc0fopen(ckpt/acc.txt,w)for i in range(100):batch_xs, batch_ys mnist.train.next_batch(100)sess.run(train_op, feed_dict{x:batch_xs, y_: batch_ys})val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels})print(epoch:%d, val_loss:%f, val_acc:%f%(i,val_loss,val_acc))f.write(str(i1), val_acc: str(val_acc)\n)if val_accmax_acc:max_accval_accsaver.save(sess,ckpt/mnist.ckpt,global_stepi1)f.close()#验证阶段else:model_filetf.train.latest_checkpoint(ckpt/)saver.restore(sess,model_file)val_loss,val_accsess.run([loss,acc],feed_dict{x: mnist.test.images, y_: mnist.test.labels})print(val_loss:%f, val_acc:%f%(val_loss,val_acc))sess.close()