要批量读取图片,您可以使用TensorFlow中的tf.data.Dataset
API。以下是一个简单的示例代码,演示了如何批量读取图片:
import tensorflow as tf # 创建一个包含图片文件路径的列表 file_paths = ["image1.jpg", "image2.jpg", "image3.jpg", ...] # 创建一个Dataset对象,将文件路径列表转换为Dataset dataset = tf.data.Dataset.from_tensor_slices(file_paths) # 定义一个函数,用于读取和解码图片 def load_and_preprocess_image(file_path): image = tf.io.read_file(file_path) image = tf.image.decode_jpeg(image, channels=3) image = tf.image.resize(image, [224, 224]) # 调整图片大小 image = tf.cast(image, tf.float32) / 255.0 # 将像素值归一化到[0, 1] return image # 使用map函数将load_and_preprocess_image函数应用到Dataset中的每个元素 dataset = dataset.map(load_and_preprocess_image) # 设置batch大小,将数据集分批次读取 batch_size = 32 dataset = dataset.batch(batch_size) # 创建一个迭代器,用于遍历数据集 iterator = iter(dataset) # 读取一个batch的图片数据 images = next(iterator) # 输出shape print(images.shape)
在这个示例中,首先创建一个包含图片文件路径的列表file_paths
,然后将这个列表转换为tf.data.Dataset
对象。定义一个函数load_and_preprocess_image
用于读取和处理图片数据。接着,使用map
函数将load_and_preprocess_image
函数应用到数据集中的每个元素,然后使用batch
函数将数据集分批次读取。最后,创建一个迭代器并使用next
函数读取一个batch的图片数据。