TensorFlow 持久化

Session Checkpoint

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
class tf.train.Saver:
def __init__(self,
var_list=None/list/dict,
reshape=False,
sharded=False,
max_to_keep=5,
keep_checkpoint_every_n_hours=10000.0,
name=None,
restore_sequentially=False,
saver_def=None,
builder=None,
defer_build=False,
allow_empty=False,
write_version=tf.train.SaverDef.V2,
pad_step_number=False,
save_relative_paths=False,
filename=None
):
self.last_checkpoints
  • 用途:保存Session中变量(张量值),将变量名映射至张量值

  • 参数

    • var_list:待保存、恢复变量,缺省所有
      • 变量需在tf.train.Saver实例化前创建
    • reshape:允许恢复并重新设定张量形状
    • sharded:碎片化保存至多个设备
    • max_to_keep:最多保存checkpoint数目
    • keep_checkpoint_every_n_hours:checkpoint有效时间
    • restore_sequentially:各设备中顺序恢复变量,可以 减少内存消耗
  • 成员

    • last_checkpoints:最近保存checkpoints

保存Session

1
2
3
4
5
6
7
8
9
10
def Saver.save(self,
sess,
save_path,
global_step=None/str,
latest_filename=None("checkpoint")/str,
meta_graph_suffix="meta",
write_meta_graph=True,
write_state=True
) -> str(path):
pass
  • 用途:保存Session,要求变量已初始化

  • 参数

    • global_step:添加至save_path以区别不同步骤
    • latest_filename:checkpoint文件名
    • meta_graph_suffix:MetaGraphDef文件名后缀

恢复Session

1
2
def Saver.restore(sess, save_path(str)):
pass
  • 用途:从save_path指明的路径中恢复模型
  • 模型路径可以通过Saver.last_checkpoints属性、 tf.train.get_checkpoint_state()函数获得

tf.train.get_checkpoint_state

1
2
3
4
5
def tf.train.get_checkpoint_state(
checkpoint_dir(str),
latest_filename=None
):
pass
  • 用途:获取指定checkpoint目录下checkpoint状态
    • 需要图结构已经建好、Session开启
    • 恢复模型得到的变量无需初始化
1
2
3
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
saver.restore(ckpt.model_checkpoint_path)
saver.restore(ckpt.all_model_checkpoint_paths[-1])

Graph Saver

tf.train.write_graph

1
2
3
4
5
6
def tf.train.write_graph(
graph_or_graph_def: tf.Graph,
logdir: str,
name: str,
as_text=True
)
  • 用途:存储图至文件中

  • 参数

    • as_text:以ASCII方式写入文件

Summary Saver

tf.summary.FileWriter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class tf.summary.FileWriter:
def __init__(self,
?path=str,
graph=tf.Graph
)

# 添加summary记录
def add_summary(self,
summary: OP,
global_step
):
pass

# 关闭`log`记录
def close(self):
pass
  • 用途:创建FileWriter对象用于记录log

    • 存储图到文件夹中,文件名由TF自行生成
    • 可通过TensorBoard组件查看生成的event log文件
  • 说明

    • 一般在图定义完成后、Session执行前创建FileWriter 对象,Session结束后关闭

实例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
 # 创建自定义summary
with tf.name_scope("summaries"):
tf.summary.scalar("loss", self.loss)
tf.summary.scalar("accuracy", self.accuracy)
tf.summary.histogram("histogram loss", self.loss)
summary_op = tf.summary.merge_all()

saver = tf.train.Saver()

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())

# 从checkpoint中恢复Session
ckpt = tf.train.get_check_state(os.path.dirname("checkpoint_dir"))
if ckpt and ckpt.model_check_path:
saver.restore(sess, ckpt.mode_checkpoint_path)

# summary存储图
writer = tf.summary.FileWriter("./graphs", sess.graph)
for index in range(10000):
loas_batch, _, summary = session.run([loss, optimizer, summary_op])
writer.add_summary(summary, global_step=index)

if (index + 1) % 1000 = 0:
saver.save(sess, "checkpoint_dir", index)

# 关闭`FileWriter`,生成event log文件
write.close()

数据持久化

DBM

DBM文件:python库中数据库管理的标准工具之一

  • 实现了数据的随机访问
    • 可以使用键访问存储的文本字符串
  • DBM文件有多个实现
    • python标准库中dbm/dbm.py

使用

  • 使用DBM文件和使用内存字典类型非常相似
    • 支持通过键抓取、测试、删除对象

pickle

  • 将内存中的python对象转换为序列化的字节流,可以写入任何 输出流中
  • 根据序列化的字节流重新构建原来内存中的对象
  • 感觉上比较像XML的表示方式,但是是python专用
1
2
3
4
5
6
7
import pickle
dbfile = open("people.pkl", "wb")
pickle.dump(db, dbfile)
dbfile.close()
dbfile = open("people.pkl", "rb")
db = pickle.load(dbfile)
dbfile.close()
  • 不支持pickle序列化的数据类型
    • 套接字

shelves

  • 就像能必须打开着的、存储持久化对象的词典
    • 自动处理内容、文件之间的映射
    • 在程序退出时进行持久化,自动分隔存储记录,只获取、 更新被访问、修改的记录
  • 使用像一堆只存储一条记录的pickle文件
    • 会自动在当前目录下创建许多文件
1
2
3
4
5
6
import shelves
db = shelves.open("people-shelves", writeback=True)
// `writeback`:载入所有数据进内存缓存,关闭时再写回,
// 能避免手动写回,但会消耗内存,关闭变慢
db["bob"] = "Bob"
db.close()

copyreg

marshal

sqlite3