杰瑞科技汇

python 读取tfrecord

为什么使用 TFRecord?

在直接使用之前,理解为什么它很重要:

  • 高效存储:TFRecord 是一种二进制格式,将数据序列化并打包成大文件,这比存储成数千个小文件(如 .jpg, .png)更高效,尤其是在分布式训练中,可以减少 I/O 操作和网络开销。
  • 跨平台兼容:可以在任何支持 TensorFlow 的平台上读写,包括 Python、C++、Go 等。
  • 优化读取性能tf.data.TFRecordDataset 可以高效地流式读取 TFRecord 文件,非常适合与 TensorFlow 的数据输入管道(tf.data API)无缝集成,实现数据的并行加载和预处理。
  • 存储多种数据类型:不仅可以存储图像,还可以存储标签、元数据(如文件名、边界框坐标)等任何可以序列化的数据。

核心概念:tf.train.Exampletf.train.Feature

TFRecord 文件中的每一条记录都是一个序列化的 tf.train.Example 协议缓冲区,而一个 Example 对象由一系列的“特性”(Feature)组成。

  • tf.train.Example: 可以看作是一个字典,键是字符串(特性名),值是 tf.train.Feature
  • tf.train.Feature: 这是存储实际数据的地方,为了统一处理,所有数据类型都必须被转换成以下三种类型之一:
    1. tf.train.BytesList: 用于存储文本、二进制数据(如 JPEG 图像编码后的字节)。
      • bytes_list = feature bytes_list { value: "image_data_in_bytes" }
    2. tf.train.FloatList: 用于存储浮点数。
      • float_list = feature float_list { value: [1.0, 2.5, 3.14] }
    3. tf.train.Int64List: 用于存储整数和布尔值。
      • int64_list = feature int64_list { value: [100, 200] }

这个设计确保了无论你的原始数据是什么格式,都可以被转换成一种标准化的结构进行存储和读取。


读取 TFRecord 文件的详细步骤

我们将使用 tf.data.TFRecordDataset,这是最推荐、最现代的方式。

步骤 1:准备一个 TFRecord 文件

如果你没有现成的 TFRecord 文件,可以先创建一个,这里我们创建一个包含图像数据和标签的简单示例。

import tensorflow as tf
import numpy as np
from PIL import Image
import io
# 1. 准备一些虚拟数据
# 创建一个随机的 RGB 图像 (32x32x3)
image_data = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
# 将 numpy 数组转换为 PIL 图像,然后编码为 JPEG 格式的字节
image_bytes = io.BytesIO()
Image.fromarray(image_data).save(image_bytes, format='JPEG')
image_bytes = image_bytes.getvalue()
label = 42  # 示例标签
# 2. 定义将数据转换为 Example 的函数
def create_example(image_bytes, label):
    feature = {
        'image/encoded': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_bytes])),
        'image/label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
# 3. 写入 TFRecord 文件
output_path = 'example.tfrecord'
with tf.io.TFRecordWriter(output_path) as writer:
    # 写入 10 条记录
    for i in range(10):
        # 每次循环创建一个新的 Example
        example = create_example(image_bytes, label + i)
        # 序列化 Example 并写入文件
        writer.write(example.SerializeToString())
print(f"TFRecord 文件已创建: {output_path}")

步骤 2:定义解析函数

读取时,你需要一个函数来将二进制数据解码回原始的 tf.train.Example,然后再从 Example 中提取出 Feature

# 定义解析函数
def parse_tfrecord(example_proto):
    # 定义每个 Feature 的名称和类型
    # 注意:这里我们使用 tf.io.FixedLenFeature,因为我们知道每个图像和标签的长度是固定的
    # 对于可变长度的数据(如文本序列),可以使用 tf.io.VarLenFeature
    features_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/label': tf.io.FixedLenFeature([], tf.int64),
    }
    # 使用 tf.io.parse_single_example 解析单个示例
    # 返回一个字典,键与 features_description 中的键相同
    parsed_features = tf.io.parse_single_example(example_proto, features_description)
    # 提取并解码数据
    image = tf.image.decode_jpeg(parsed_features['image/encoded'], channels=3)
    label = parsed_features['image/label']
    return image, label

步骤 3:创建 TFRecordDataset 并应用解析函数

# 创建数据集
# TFRecordDataset 直接从文件路径读取
raw_dataset = tf.data.TFRecordDataset(['example.tfrecord'])
# 使用 map 应用解析函数
# num_parallel_calls=tf.data.AUTOTUNE 可以并行处理,提高效率
parsed_dataset = raw_dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
# 查看数据集中的前几条数据
for image, label in parsed_dataset.take(3):
    print("标签:", label.numpy())
    print("图像张量的形状:", image.shape)
    # 你可以使用 matplotlib 来可视化图像
    # import matplotlib.pyplot as plt
    # plt.imshow(image.numpy())
    # plt.title(f"Label: {label.numpy()}")
    # plt.show()

完整代码示例与最佳实践

下面是一个更完整的脚本,它包含了创建和读取两个部分,并展示了如何构建一个完整的数据输入管道。

import tensorflow as tf
import numpy as np
import os
# --- 1. 创建 TFRecord 文件 (如果不存在) ---
TFRECORD_FILE = 'my_data.tfrecord'
if not os.path.exists(TFRECORD_FILE):
    print("创建 TFRecord 文件...")
    # 模拟数据
    num_samples = 100
    images = np.random.rand(num_samples, 64, 64, 3).astype(np.float32) # 模拟浮点图像
    labels = np.random.randint(0, 10, size=num_samples) # 0-9 的整数标签
    metadata = [f"sample_{i}" for i in range(num_samples)] # 字符串元数据
    def serialize_example(image, label, metadata):
        feature = {
            'image': tf.train.Feature(float_list=tf.train.FloatList(value=image.flatten())),
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'metadata': tf.train.Feature(bytes_list=tf.train.BytesList(value=[metadata.encode('utf-8')])),
        }
        example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
        return example_proto.SerializeToString()
    with tf.io.TFRecordWriter(TFRECORD_FILE) as writer:
        for i in range(num_samples):
            example = serialize_example(images[i], labels[i], metadata[i])
            writer.write(example)
    print("创建完成。")
# --- 2. 读取 TFRecord 文件 ---
# 定义解析函数
def parse_tfrecord_function(example_proto):
    # 定义每个 Feature 的描述
    features_description = {
        'image': tf.io.FixedLenFeature([64 * 64 * 3], tf.float32),
        'label': tf.io.FixedLenFeature([], tf.int64),
        'metadata': tf.io.FixedLenFeature([], tf.string),
    }
    # 解析单个示例
    parsed_features = tf.io.parse_single_example(example_proto, features_description)
    # 重塑图像数据
    image = tf.reshape(parsed_features['image'], [64, 64, 3])
    label = parsed_features['label']
    metadata = parsed_features['metadata']
    return image, label, metadata
# 创建 TFRecordDataset
raw_dataset = tf.data.TFRecordDataset([TFRECORD_FILE])
# 应用解析函数
parsed_dataset = raw_dataset.map(parse_tfrecord_function, num_parallel_calls=tf.data.AUTOTUNE)
# --- 3. 构建优化的数据输入管道 (tf.data best practices) ---
# 1. 打乱数据
SHUFFLE_BUFFER_SIZE = 100
shuffled_dataset = parsed_dataset.shuffle(SHUFFLE_BUFFER_SIZE)
# 2. 批处理
BATCH_SIZE = 16
batched_dataset = shuffled_dataset.batch(BATCH_SIZE)
# 3. 预取 (Prefetching)
# 在 GPU 处理当前批次的数据时,CPU 可以在后台准备下一批次的数据
prefetched_dataset = batched_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
# --- 4. 遍历并验证数据 ---
print("\n开始从数据集中读取数据...")
for batch_images, batch_labels, batch_metadata in prefetched_dataset.take(2): # 只取前两个批次
    print("\n--- 新批次 ---")
    print(f"批次图像张量形状: {batch_images.shape}")
    print(f"批次标签: {batch_labels.numpy()}")
    print(f"批次元数据: {batch_metadata.numpy()}")
    print(f"第一个样本的元数据: {batch_metadata[0].numpy().decode('utf-8')}")
print("\n读取和管道构建完成。")

处理不同类型的数据

  • 图像: 如上所示,使用 tf.image.decode_jpegtf.image.decode_png
  • 文本: 文本字符串需要先编码为字节,读取后再解码。
    • 写入: tf.train.Feature(bytes_list=tf.train.BytesList(value=[text.encode('utf-8')]))
    • 读取: tf.strings.decode(parsed_features['text'], 'utf-8')
  • 可变长度数据(如序列): 使用 tf.io.VarLenFeature
    • 写入: tf.train.Feature(int64_list=tf.train.Int64List(value=[1, 2, 3]))
    • 读取: parsed_features['sequence'] 会返回一个 SparseTensor,可以使用 tf.sparse.to_dense 将其转换为密集张量。

读取 TFRecord 文件的标准流程是:

  1. 创建 tf.data.TFRecordDataset:传入 TFRecord 文件路径列表。
  2. 定义 parse_... 函数:使用 tf.io.parse_single_example 和一个描述字典来解码二进制数据。
  3. 使用 dataset.map():将解析函数应用到数据集中的每一条记录。
  4. 构建管道:按照 shuffle -> batch -> prefetch 的顺序优化数据流,以最大化训练性能。

掌握 TFRecord 的读写是进行大规模深度学习项目的重要技能,它能让你的数据加载效率大大提升。

分享:
扫描分享到社交APP
上一篇
下一篇