TensorFlow Python IO接口
TFRecord
TFRecord格式：序列化的 tf.train.Example protobuf 对象
`class tf.python_io.TFRecordWriter`
示例
转换、写入TFRecord
# Write one image and its metadata to a TFRecord file as a serialized
# tf.train.Example.
# NOTE(review): `out_file`, `image_file`, `label` and `get_image_binary` are
# defined outside this snippet. `label` is assumed to be an int, and `shape` /
# `binary_image` are assumed to be bytes (BytesList only holds bytes) -- confirm.

# Create the writer for the output file (the class is TFRecordWriter).
writer = tf.python_io.TFRecordWriter(out_file)
shape, binary_image = get_image_binary(image_file)

# Build the Features message. Protobuf messages only accept keyword fields, and
# the list messages' `value` field is repeated, so each scalar is wrapped in a list.
features = tf.train.Features(
    feature={
        "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
        "shape": tf.train.Feature(bytes_list=tf.train.BytesList(value=[shape])),
        "image": tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[binary_image])
        ),
    }
)
# Wrap the features in an Example message.
sample = tf.train.Example(features=features)
# Serialize the Example and append it to the file; writer.close() must still be
# called after the last write.
writer.write(sample.SerializeToString())
writer.close()

读取TFRecord
def _parse_function(tf_record_serialized):
    """Parse one serialized tf.train.Example into ``(label, shape, image)``.

    The feature keys must match the keys used when the records were written
    ("label", "shape", "image").

    Args:
        tf_record_serialized: scalar string tensor holding one serialized
            Example (as yielded by ``tf.data.TFRecordDataset``).

    Returns:
        Tuple ``(label, shape, image)``: an int64 tensor followed by two
        string tensors, per the feature spec below.
    """
    features = {
        "label": tf.FixedLenFeature([], tf.int64),
        "shape": tf.FixedLenFeature([], tf.string),
        "image": tf.FixedLenFeature([], tf.string),
    }
    parsed_features = tf.parse_single_example(tf_record_serialized, features)
    return (
        parsed_features["label"],
        parsed_features["shape"],
        parsed_features["image"],
    )


# _parse_function must be defined before it is passed to map().
dataset = tf.data.TFRecordDataset(tfrecord_files)
dataset = dataset.map(_parse_function)