TensorFlow模型保存和提取方法示例
本篇文章主要介绍了TensorFlow模型保存和提取方法示例,现在分享给大家,也给大家做个参考。一起过来看看吧
一、TensorFlow模型保存和提取方法
1. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取。tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,"Model/model.ckpt"),实际在这个文件目录下会生成4个人文件:
checkpoint文件保存了一个录下多有的模型文件列表,model.ckpt.meta保存了TensorFlow计算图的结构信息,model.ckpt保存每个变量的取值,此处文件名的写入方式会因不同参数的设置而不同,但加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的。
2. 加载这个已保存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt"),加载模型的代码中也要定义TensorFlow计算图上的所有运算并声明一个tf.train.Saver类,不同的是加载模型时不需要进行变量的初始化,而是将变量的取值通过保存的模型加载进来,注意加载路径的写法。若不希望重复定义计算图上的运算,可直接加载已经持久化的图,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")。
3.tf.train.Saver类也支持在保存和加载时给变量重命名,声明Saver类对象的时候使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名},saver = tf.train.Saver({"v1":u1, "v2": u2})即原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中。
4. 上一条做的目的之一就是方便使用变量的滑动平均值。如果在加载模型时直接将影子变量映射到变量自身,则在使用训练好的模型时就不需要再调用函数来获取变量的滑动平均值了。载入时,声明Saver类对象时通过一个字典将滑动平均值直接加载到新的变量中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典。
此外,通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中。
二、TensorFlow程序实现
# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行 # 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确 # Part1: 通过tf.train.Saver类实现保存和载入神经网络模型 # 执行本段程序时注意当前的工作路径 import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.save(sess, "Model/model.ckpt") # Part2: 加载TensorFlow模型的方法 import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" print(sess.run(result)) # [ 3.] # Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图 import tensorflow as tf saver = tf.train.import_meta_graph("Model/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.] # Part4: tf.train.Saver类也支持在保存和加载时给变量重命名 import tensorflow as tf # 声明的变量名称name与已保存的模型中的变量名称name不一致 u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") result = u1 + u2 # 若直接生命Saver类对象,会报错变量找不到 # 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} # 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 saver = tf.train.Saver({"v1": u1, "v2": u2}) with tf.Session() as sess: saver.restore(sess, "./Model/model.ckpt") print(sess.run(result)) # [ 3.] # Part5: 保存滑动平均模型 import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") for variables in tf.global_variables(): print(variables.name) # v:0 ema = tf.train.ExponentialMovingAverage(0.99) maintain_averages_op = ema.apply(tf.global_variables()) for variables in tf.global_variables(): print(variables.name) # v:0 # v/ExponentialMovingAverage:0 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.global_variables_initializer()) sess.run(tf.assign(v, 10)) sess.run(maintain_averages_op) saver.save(sess, "Model/model_ema.ckpt") print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905] # Part6: 通过变量重命名直接读取变量的滑动平均值 import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) with tf.Session() as sess: saver.restore(sess, "./Model/model_ema.ckpt") print(sess.run(v)) # 0.0999999 # Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典 import tensorflow as tf v = tf.Variable(0, dtype=tf.float32, name="v") # 注意此处的变量名称name一定要与已保存的变量名称一致 ema = tf.train.ExponentialMovingAverage(0.99) print(ema.variables_to_restore()) # {'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>} # 此处的v取自上面变量v的名称name="v" saver = tf.train.Saver(ema.variables_to_restore()) with tf.Session() as sess: saver.restore(sess, "./Model/model_ema.ckpt") print(sess.run(v)) # 0.0999999 # Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中 import tensorflow as tf from tensorflow.python.framework import graph_util v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") result = v1 + v2 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分 graph_def = tf.get_default_graph().as_graph_def() output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add']) with tf.gfile.GFile("Model/combined_model.pb", 'wb') as f: f.write(output_graph_def.SerializeToString()) # Part9: 载入包含变量及其取值的模型 import tensorflow as tf from tensorflow.python.platform import gfile with tf.Session() as sess: model_filename = "Model/combined_model.pb" with gfile.FastGFile(model_filename, 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) result = tf.import_graph_def(graph_def, return_elements=["add:0"]) print(sess.run(result)) # [array([ 3.], dtype=float32)]
相关推荐:
以上是TensorFlow模型保存和提取方法示例的详细内容。更多信息请关注PHP中文网其他相关文章!

热AI工具

Undresser.AI Undress
人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover
用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool
免费脱衣服图片

Clothoff.io
AI脱衣机

Video Face Swap
使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热门文章

热工具

记事本++7.3.1
好用且免费的代码编辑器

SublimeText3汉化版
中文版,非常好用

禅工作室 13.0.1
功能强大的PHP集成开发环境

Dreamweaver CS6
视觉化网页开发工具

SublimeText3 Mac版
神级代码编辑软件(SublimeText3)

微信是主流的聊天工具之一,我们可以通过微信认识新的朋友,联系老的朋友,维系朋友之间的情谊。正如天下没有不散的宴席,人与人之间的相处难免会发生意见不合的时候。当一个人极其影响你的情绪,或者在相处的时候发现三观不合,没办法再继续沟通,那么我们可能需要删除微信好友的方法。怎么删除微信好友?删除微信好友的方法第一步:在微信主界面轻触【通讯录】;第二步:点击对应要删除的好友,进入【详细资料】;第三步:点击右上角【...】;第四步:点击下方【删除】即可;第五步:了解后页面提示后,点击【删除联系人】即可;温馨

而后悔莫及、人们常常会因为一些原因不小心将某些联系人删除、微信作为一款广泛使用的社交软件。帮助用户解决这一问题,本文将介绍如何通过简单的方法找回被删除的联系人。1.了解微信联系人删除机制这为我们找回被删除的联系人提供了可能性、微信中的联系人删除机制是将其从通讯录中移除,但并未完全删除。2.使用微信内置“通讯录恢复”功能微信提供了“通讯录恢复”节省时间和精力,用户可以通过该功能快速找回之前被删除的联系人,功能。3.进入微信设置页面点击右下角,打开微信应用“我”再点击右上角设置图标、进入设置页面,,

七彩虹主板在中国国内市场享有较高的知名度和市场占有率,但是有些七彩虹主板的用户还不清楚怎么进入bios进行设置呢?针对这一情况,小编专门为大家带来了两种进入七彩虹主板bios的方法,快来试试吧! 方法一:使用u盘启动快捷键直接进入u盘装系统 七彩虹主板一键启动u盘的快捷键是ESC或F11,首先使用黑鲨装机大师制作一个黑鲨U盘启动盘,然后开启电脑,当看到开机画面的时候,连续按下键盘上的ESC或F11键以后将会进入到一个启动项顺序选择的窗口,将光标移动到显示“USB”的地方,然

番茄小说是一款非常热门的小说阅读软件,我们在番茄小说中经常会有新的小说和漫画可以去阅读,每一本小说和漫画都很有意思,很多小伙伴也想着要去写小说来赚取赚取零花钱,在把自己想要写的小说内容编辑成文字,那么我们要怎么样在这里面去写小说呢?小伙伴们都不知道,那就让我们一起到本站本站中花点时间来看写小说的方法介绍吧。分享番茄小说写小说方法教程 1、首先在手机上打开番茄免费小说app,点击个人中心——作家中心 2、跳转到番茄作家助手页面——点击创建新书在小说的结

手机游戏成为了人们生活中不可或缺的一部分,随着科技的发展。它以其可爱的龙蛋形象和有趣的孵化过程吸引了众多玩家的关注,而其中一款备受瞩目的游戏就是手机版龙蛋。帮助玩家们在游戏中更好地培养和成长自己的小龙,本文将向大家介绍手机版龙蛋的孵化方法。1.选择合适的龙蛋种类玩家需要仔细选择自己喜欢并且适合自己的龙蛋种类,根据游戏中提供的不同种类的龙蛋属性和能力。2.提升孵化机的等级玩家需要通过完成任务和收集道具来提升孵化机的等级,孵化机的等级决定了孵化速度和孵化成功率。3.收集孵化所需的资源玩家需要在游戏中

Win11管理员权限获取方法汇总在Windows11操作系统中,管理员权限是非常重要的权限之一,可以让用户对系统进行各种操作。有时候,我们可能需要获取管理员权限来完成一些操作,比如安装软件、修改系统设置等。下面就为大家总结了一些获取Win11管理员权限的方法,希望能帮助到大家。1.使用快捷键在Windows11系统中,可以通过快捷键的方式快速打开命令提

Oracle版本查询方法详解Oracle是目前世界上最流行的关系型数据库管理系统之一,它提供了丰富的功能和强大的性能,广泛应用于企业中。在进行数据库管理和开发过程中,了解Oracle数据库的版本是非常重要的。本文将详细介绍如何查询Oracle数据库的版本信息,并给出具体的代码示例。查询数据库版本的SQL语句在Oracle数据库中,可以通过执行简单的SQL语句

字体大小的设置成为了一项重要的个性化需求,随着手机成为人们日常生活的重要工具。以满足不同用户的需求、本文将介绍如何通过简单的操作,提升手机使用体验,调整手机字体大小。为什么需要调整手机字体大小-调整字体大小可以使文字更清晰易读-适合不同年龄段用户的阅读需求-方便视力不佳的用户使用手机系统自带字体大小设置功能-如何进入系统设置界面-在设置界面中找到并进入"显示"选项-找到"字体大小"选项并进行调整第三方应用调整字体大小-下载并安装支持字体大小调整的应用程序-打开应用程序并进入相关设置界面-根据个人
