
Saving and restoring a trained model in TensorFlow 1.0 (Saver)

Apr 23, 2018 pm 03:42 PM

This article introduces how to save and restore a trained model in TensorFlow 1.0 with Saver.

Saving trained model parameters for later validation or testing is something we do often. TensorFlow provides the tf.train.Saver() class for this.

To save a model, first create a Saver object, for example:

saver=tf.train.Saver()

When creating the Saver object, one parameter we often use is max_to_keep, which sets how many recent checkpoints are kept. The default is 5 (max_to_keep=5), so the latest 5 models are kept. If you want to keep the model from every training generation (epoch), set max_to_keep to None or 0, for example:

saver=tf.train.Saver(max_to_keep=0)

But this mostly just takes up more disk space and is of little practical use, so it is not recommended.

If you only want to keep the most recent model, set max_to_keep to 1:

saver=tf.train.Saver(max_to_keep=1)
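
To make the effect of max_to_keep concrete, here is a minimal plain-Python sketch of the rotation rule described above. It does not call TensorFlow, and the function name surviving_checkpoints is hypothetical; it only mimics which checkpoints would remain on disk after a series of saves.

```python
from collections import deque


def surviving_checkpoints(saved_steps, max_to_keep=5):
    """Return the steps whose checkpoints remain after all saves.

    Illustration of the keep-the-latest-N rule only, not TensorFlow code.
    None or 0 means "never delete anything".
    """
    if not max_to_keep:
        return list(saved_steps)
    # A bounded deque drops its oldest entry when a new one is appended.
    recent = deque(maxlen=max_to_keep)
    for step in saved_steps:
        recent.append(step)
    return list(recent)


steps = range(1, 11)  # ten saves, at steps 1..10
print(surviving_checkpoints(steps))                 # default max_to_keep=5
print(surviving_checkpoints(steps, max_to_keep=1))  # only the newest
print(surviving_checkpoints(steps, max_to_keep=0))  # everything is kept
```

With ten saves and the default setting, only the checkpoints from steps 6 to 10 remain.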

After creating the saver object, you can save the trained model, for example:

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)

The first parameter is the session, sess. The second sets the path and name of the saved model, and the third appends the training step count to the model name as a suffix:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
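
The naming rule shown above can be sketched in plain Python. The helper name checkpoint_prefix is hypothetical and the code does not call TensorFlow. Note that the result is a prefix rather than a single file: TensorFlow 1.x typically writes several files that share it (for example with .index, .meta and .data-00000-of-00001 endings).

```python
def checkpoint_prefix(save_path, global_step=None):
    """Build the checkpoint name the way the examples above show it.

    Illustration only: without a global_step the path is used unchanged,
    otherwise the step is appended after a hyphen.
    """
    if global_step is None:
        return save_path
    return "{}-{}".format(save_path, global_step)


print(checkpoint_prefix("my-model", 0))
print(checkpoint_prefix("ckpt/mnist.ckpt", 100))
```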

Here is an MNIST example:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
    # Save the model; with max_to_keep=1 this replaces the previous checkpoint
    saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

The saver.save call is the code that saves the model. Although the model is saved after every generation of training (here each loop iteration is one mini-batch step, which the code labels an epoch), each save replaces the previous one because max_to_keep=1, so only the last checkpoint survives. In that case we can save time by moving the saving code after the loop. This only applies to max_to_keep=1; for any other value it must stay inside the loop.

In an experiment, the last generation is not necessarily the one with the highest validation accuracy, so instead of saving the last generation by default we want to save the generation with the highest validation accuracy. This only needs an intermediate variable and an if statement.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
    # Save only when the validation accuracy beats the best seen so far
    if val_acc>max_acc:
        max_acc=val_acc
        saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
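
The save-on-improvement rule in the loop above can be checked without TensorFlow. The following is a small sketch with a hypothetical helper, steps_that_trigger_save, and made-up accuracy values; it reports at which steps a save would have happened.

```python
def steps_that_trigger_save(val_accs):
    """Return the 1-based steps at which the loop above would call save.

    Mirrors the max_acc logic only; the accuracies are passed in as a list
    instead of being computed by a model.
    """
    max_acc = 0
    saved_steps = []
    for step, val_acc in enumerate(val_accs, start=1):
        # Strict comparison: a tie with the current best does not re-save.
        if val_acc > max_acc:
            max_acc = val_acc
            saved_steps.append(step)
    return saved_steps


# Made-up validation accuracies for five steps.
print(steps_that_trigger_save([0.80, 0.85, 0.83, 0.85, 0.90]))
```

Because the comparison is strict, the second 0.85 does not overwrite the checkpoint saved at the first 0.85.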

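If we want to keep the three generations with the highest validation accuracy and also record the validation accuracy of every generation, we can set max_to_keep=3 and write the accuracies to a txt file. Because a checkpoint is saved only when the accuracy improves, the three most recent checkpoints are also the three best. The ckpt/ directory must already exist, since open() does not create it.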

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
    # Record the validation accuracy of every step
    f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
    if val_acc>max_acc:
        max_acc=val_acc
        saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()
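
Once acc.txt has been written, it can be read back to find the best step. This is a minimal sketch that assumes the exact line format produced by the f.write call above ("step, val_acc: value"); the helper name best_step_from_log and the sample values are hypothetical.

```python
def best_step_from_log(lines):
    """Return (step, accuracy) of the best entry in an acc.txt-style log.

    `lines` can be an open file or any iterable of strings. An empty log
    gives (None, -inf).
    """
    best_step, best_acc = None, float("-inf")
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Each line looks like: "12, val_acc: 0.9134"
        step_text, acc_text = line.split(", val_acc: ")
        acc = float(acc_text)
        if acc > best_acc:
            best_step, best_acc = int(step_text), acc
    return best_step, best_acc


sample = ["1, val_acc: 0.81\n", "2, val_acc: 0.92\n", "3, val_acc: 0.90\n"]
print(best_step_from_log(sample))
```

To use it on the real log, pass the open file object, for example inside a with open('ckpt/acc.txt') as f: block.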

To restore a model, use the Saver's restore() function, which takes two parameters: restore(sess, save_path). save_path is the path of the saved model. We can use tf.train.latest_checkpoint() to get the most recently saved model automatically. For example:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
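
tf.train.latest_checkpoint() finds the newest model through a small text file named checkpoint that the Saver maintains in the save directory. The sketch below is a plain-Python stand-in, not TensorFlow's implementation: it reads the model_checkpoint_path entry of such a file and, unlike the real function, does not verify that the checkpoint files themselves exist. The helper name latest_from_state_file is hypothetical.

```python
import os
import re
import tempfile


def latest_from_state_file(checkpoint_dir):
    """Return the path prefix recorded as latest in `checkpoint_dir`, or None."""
    state_path = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_path):
        return None
    with open(state_path) as f:
        for line in f:
            # The latest model is recorded as: model_checkpoint_path: "name"
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                return os.path.join(checkpoint_dir, match.group(1))
    return None


# Demo with a hand-written state file in a temporary directory.
demo_dir = tempfile.mkdtemp()
with open(os.path.join(demo_dir, "checkpoint"), "w") as f:
    f.write('model_checkpoint_path: "mnist.ckpt-100"\n')
    f.write('all_model_checkpoint_paths: "mnist.ckpt-83"\n')
    f.write('all_model_checkpoint_paths: "mnist.ckpt-100"\n')
print(latest_from_state_file(demo_dir))
```

The real tf.train.latest_checkpoint() also returns None when no checkpoint is found, so it is worth checking model_file before passing it to saver.restore().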

Then we can change the second half of the program to:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

# Training phase
if is_train:
    max_acc=0
    f=open('ckpt/acc.txt','w')
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
        val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
        f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
        # Save only when the validation accuracy improves
        if val_acc>max_acc:
            max_acc=val_acc
            saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
    f.close()

# Validation phase
else:
    # Restore the most recently saved checkpoint and evaluate it
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

The saver.save and saver.restore calls, together with tf.train.latest_checkpoint, are the code related to saving and restoring the model. A bool variable is_train controls whether the program runs the training phase or the validation phase.
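
Instead of editing is_train in the source between runs, the flag could be read from the command line. This is an optional sketch and not part of the original program; the --train flag and the parse_mode helper are assumptions introduced for illustration.

```python
import argparse


def parse_mode(argv=None):
    """Parse a hypothetical --train flag into an `is_train` boolean."""
    parser = argparse.ArgumentParser(
        description="Train the MNIST model or validate a saved checkpoint.")
    parser.add_argument(
        "--train", dest="is_train", action="store_true",
        help="train and save checkpoints; without it, restore and validate")
    return parser.parse_args(argv)


# Passing an explicit list keeps the demo independent of the real command line.
print(parse_mode(["--train"]).is_train)
print(parse_mode([]).is_train)
```

In the program, is_train=parse_mode().is_train would then replace the hard-coded assignment.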

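Entire source program (the ckpt/ directory must exist before running, because open('ckpt/acc.txt','w') does not create it):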

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

# Training phase
if is_train:
    max_acc=0
    f=open('ckpt/acc.txt','w')
    for i in range(100):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
        val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
        f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
        # Save only when the validation accuracy improves
        if val_acc>max_acc:
            max_acc=val_acc
            saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
    f.close()

# Validation phase
else:
    # Restore the most recently saved checkpoint and evaluate it
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
