首页 后端开发 Python教程 浅谈Tensorflow模型的保存与恢复加载

浅谈Tensorflow模型的保存与恢复加载

Apr 26, 2018 pm 04:40 PM
tensorflow 保持 恢复

本篇文章主要介绍了浅谈Tensorflow模型的保存与恢复加载,现在分享给大家,也给大家做个参考。一起过来看看吧

近期做了一些反垃圾的工作,除了使用常用的规则匹配过滤等手段,也采用了一些机器学习方法进行分类预测。我们使用TensorFlow进行模型的训练,训练好的模型需要保存,预测阶段我们需要将模型进行加载还原使用,这就涉及TensorFlow模型的保存与恢复加载。

总结一下Tensorflow常用的模型保存方式。

保存checkpoint模型文件(.ckpt)

首先,TensorFlow提供了一个非常方便的api,tf.train.Saver()来保存和还原一个机器学习模型。

模型保存

使用tf.train.Saver()来保存模型文件非常方便,下面是一个简单的例子:


import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))
登录后复制


程序生成并保存四个文件(在版本0.11之前只会生成三个文件:checkpoint, model.ckpt, model.ckpt.meta)

  1. checkpoint 文本文件,记录了模型文件的路径信息列表

  2. model.ckpt.data-00000-of-00001 网络权重信息

  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息

  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

以上是tf.train.Saver().save()的基本用法,save()方法还有很多可配置的参数:


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
登录后复制


加上global_step参数代表在每1000次迭代后保存模型,会在模型文件后加上"-1000",model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt.data-1000-00000-of-00001

每1000次迭代保存一次模型,但是模型的结构信息文件不会变,就只用1000次迭代时保存一下,不用相应的每1000次保存一次,所以当我们不需要保存meta文件时,可以加上write_meta_graph=False参数,如下:


复制代码 代码如下:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
登录后复制

如果想每两小时保存一次模型,并且只保存最新的4个模型,可以加上使用max_to_keep(默认值为5,如果想每训练一个epoch就保存一次,可以将其设置为None或0,但是没啥用不推荐), keep_checkpoint_every_n_hours参数,如下:


复制代码 代码如下:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
登录后复制


同时在tf.train.Saver()类中,如果我们不指定任何信息,则会保存所有的参数信息,我们也可以指定部分想要保存的内容,例如只保存x, y参数(可传入参数list或dict):


tf.train.Saver([x, y]).save(sess, ckpt_file_path)
登录后复制


ps. 在模型训练过程中需要在保存后拿到的变量或参数名属性name不能丢,不然模型还原后不能通过get_tensor_by_name()获取。

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:


import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)
登录后复制


首先还原模型结构,然后还原变量(参数)信息,最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

《tensorflow 1.0 学习:模型的保存与恢复(Saver)》有一些Saver使用技巧。

保存单个模型文件(.pb)

我自己运行过Tensorflow的inception-v3的demo,发现运行结束后会生成一个.pb的模型文件,这个文件是作为后续预测或迁移学习使用的,就一个文件,非常炫酷,也十分方便。

这个过程的主要思路是graph_def文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant(使用graph_util.convert_variables_to_constants()函数),即可达到使用一个文件同时存储网络架构与权重的目标。

ps:这里.pb是模型文件的后缀名,当然我们也可以用其它的后缀(使用.pb与google保持一致 ╮(╯▽╰)╭)

模型保存

同样根据上面的例子,一个简单的demo:


import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # 这里的输出需要加上name属性
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(pb_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
  with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))
登录后复制


程序生成并保存一个文件

model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:


import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

  print(sess.run('b:0'))

  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op_to_store:0')

  ret = sess.run(op, {input_x: 5, input_y: 5})
  print(ret)
登录后复制


模型的还原过程与checkpoint差不多一样。

《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。

思考

模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。

同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。

选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。

以上完整代码:github https://github.com/liuyan731/tf_demo

相关推荐:

TensorFlow模型保存和提取方法示例


以上是浅谈Tensorflow模型的保存与恢复加载的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

Video Face Swap

Video Face Swap

使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

<🎜>:泡泡胶模拟器无穷大 - 如何获取和使用皇家钥匙
4 周前 By 尊渡假赌尊渡假赌尊渡假赌
北端:融合系统,解释
4 周前 By 尊渡假赌尊渡假赌尊渡假赌
Mandragora:巫婆树的耳语 - 如何解锁抓钩
3 周前 By 尊渡假赌尊渡假赌尊渡假赌

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

热门话题

Java教程
1672
14
CakePHP 教程
1428
52
Laravel 教程
1332
25
PHP教程
1276
29
C# 教程
1256
24
微信文件过期怎么恢复 微信的过期文件能恢复吗 微信文件过期怎么恢复 微信的过期文件能恢复吗 Feb 22, 2024 pm 02:46 PM

打开微信,在我中选择设置,选择通用后选择存储空间,在存储空间选择管理,选择要恢复文件的对话选择感叹号图标。教程适用型号:iPhone13系统:iOS15.3版本:微信8.0.24解析1首先打开微信,在我的页面中点击设置选项。2接着在设置页面中找到并点击通用选项。3然后在通用页面中点击存储空间。4接下来在存储空间页面中点击管理。5最后选择要恢复文件的对话,点击右侧的感叹号图标。补充:微信文件一般几天过期1要是微信接收的文件并没有点开过的情况下,那在七十二钟头之后微信系统会清除掉,要是己经查看了微信

如何恢复无痕模式下的浏览记录 如何恢复无痕模式下的浏览记录 Feb 19, 2024 pm 04:22 PM

无痕浏览是一种非常方便的浏览方式,可以在使用电脑或移动设备上网时保护个人隐私。无痕浏览模式通常会阻止浏览器记录访问历史、保存Cookie和缓存文件,以及防止正在浏览的网站在浏览器中留下任何痕迹。但是,对于一些特殊的情况,我们可能需要恢复无痕浏览的浏览记录。首先,我们需要明确一点:无痕浏览模式的目的是保护隐私,防止他人从浏览器中获取用户的上网记录。因此,无痕浏

抖音怎么恢复聊天火花 抖音怎么恢复聊天火花 Mar 16, 2024 pm 01:25 PM

在抖音这个充满创意与活力的短视频平台上,我们不仅可以欣赏到各种精彩内容,还能与志同道合的朋友展开深入的交流。其中,聊天火花作为衡量双方互动热度的重要指标,常常在不经意间点燃我们与好友之间的情感纽带。然而,有时由于一些原因,聊天火花可能会断开,那么如果我们想要恢复聊天火花究竟该如何操作呢,这篇教程攻略就将为大家带来详细的内容攻略介绍,希望能帮助到大家。抖音聊天火花断了怎么恢复?1、打开抖音的消息页面,选择好友聊天。2、互发消息聊天。3、连续发消息3天,就可以获得火花标识。在3天基础上,互发图片或视

小红书怎么保存无水印图片 小红书怎么拿图没有水印 小红书怎么保存无水印图片 小红书怎么拿图没有水印 Mar 22, 2024 pm 03:40 PM

  小红书拥有丰富的内容,让大家可以在这里自由的查看,让你们每天都可以使用这个软件解闷,为自己带来帮助,在使用这个软件的过程中,有时候会看到各种的美图,很多人想要保存起来,但是保存后的图片,都有水印,非常的影响,大家都想要知道在这里该怎么保存没有水印的图片,小编为你们提供方法,有需要的小伙伴们,都可以马上的了解使用起来!  1.点击图片右上角的“…”复制链接  2.打开微信小程序  3.微信小程序搜索红薯库  4.进入红薯库确定获取链接  5.获取图片保存至手机相册

小米云相册怎么恢复到本地 小米云相册怎么恢复到本地 Feb 24, 2024 pm 03:28 PM

小米云相册怎么恢复到本地?小米云相册APP中是可以恢复到本地,但是多数的小伙伴不知道小米云相册如何恢复到本地中,接下来就是小编为用户带来的小米云相册恢复到本地方法图文教程,感兴趣的用户快来一起看看吧!小米云相册怎么恢复到本地1、首先打开小米手机中的设置功能,主界面选择【个人头像】;2、然后进入到小米账号的界面,点击【云服务】功能;3、接着跳转到小米云服务的功能,选择其中的【云备份】;4、最后在如下图所示的界面,点击【云相册】即可恢复相册到本地。

恢复win11默认头像的教程 恢复win11默认头像的教程 Jan 02, 2024 pm 12:43 PM

如果我们更换了自己的系统账户头像,但是不想要了,结果找不到win11怎么更改默认头像了,其实我们只要找到默认头像的文件夹就可以恢复了。win11头像恢复默认1、首先点开底部任务栏的“Windows徽标”2、接着在其中找到并打开“设置”3、然后进入左边栏的“账户”4、随后点开右侧的“账户信息”5、打开后,点击选择照片中的“浏览文件”6、最后进入“C:\ProgramData\Microsoft\UserAccountPictures”路径就可以找到系统默认头像图片了。

win10怎么恢复默认壁纸 win10怎么恢复默认壁纸 Feb 10, 2024 pm 10:51 PM

Windows10的2019年5月更新具有新的、更亮的默认桌面背景。它看起来很棒-带有新的浅色主题。如果您使用Windows10的深色主题,您可能需要更深的背景。奇怪的是,Windows10的原始桌面背景已从最新版本的Windows10中删除。您必须从Web下载它或从旧的Windows10PC复制其文件。尽管我们无法在Microsoft的官方网站上找到此壁纸图片,但您可以从其他来源下载它。我们在Imgur上找到了一份4K分辨率的Windows10原始桌面壁纸的副本。此外,还有其他尺寸和更多默认壁

小红书被删除的评论怎么恢复?被删除的评论有提示吗? 小红书被删除的评论怎么恢复?被删除的评论有提示吗? Mar 27, 2024 am 11:56 AM

小红书作为一款流行的社交电商平台,用户可以在这里分享购物心得、生活点滴等。在使用过程中,有些用户可能会遇到自己发布的评论被删除的情况。那么,小红书被删除的评论怎么恢复呢?一、小红书被删除的评论怎么恢复?如果发现评论被误删,用户可以选择静待小红书官方团队进行恢复。在这种情况下,最好保持耐心等待,因为官方团队可能会在一段时间后自动处理并恢复评论。如果您发现评论被删除,可以考虑重新发布类似内容。但在重新发布时,请确保内容符合小红书的社区准则,以免再次遭到删除。3.联系小红书客服:如果认为自己的评论被误

See all articles