发布于 5年前

TensorFlow 2.0 tf.dataset类的使用

这个笔记主要是TensorFlow 2.0的tf.dataset接口的使用。下面的示例会把numpy array的数据写入到TFRecord文件中,以及从TFRecord文件中读取数据到numpy array。

安装

可以参考官网的https://www.tensorflow.org/install教程来安装。

安装完成后,检查安装的TensorFlow的版本:

import tensorflow as tf
print(tf.__version__)

TensorFlow Dataset的使用

在TensorFlow 2.0中,向网络灌输数据的最好方法是使用tf.dataset类,dataset本身就是一个迭代器,所以可以使用for循环的方法来迭代dataset里的数据。

1、使用numpy array来创建一个dataset:

import numpy as np
np.random.seed(0)
data = np.random.randn(256, 8, 8, 3)
dataset = tf.data.Dataset.from_tensor_slices(data)
print(dataset)
...
<TensorSliceDataset shapes: (8, 8, 3), types: tf.float64>

可以通过print()方法输出dataset,可以看到dataset的shap。

通常,第一维度的数据表示训练样本的数量。DataSet可以生产任何大小的batch size,但默认情况下,batch size的值为1, 也即是生成各自独立的训练样本。

2、迭代dataset

使用for循环可以dataset做迭代,如果想获取每个批次的数据,可以使用Python的enumerate,或者使用Dataset自身的方法enumerate(),迭代示例如下:

for i, batch in enumerate(dataset):
  if i == 255 or i == 256:
  print(i, batch.shape)
...
255 (8, 8, 3)
...
for i, batch in dataset.enumerate():
  if i == 255 or i == 256:
  print(i, batch.shape)
  print(i.numpy(), batch.shape)
...
tf.Tensor(255, shape=(), dtype=int64) (8, 8, 3) 255 (8, 8, 3)
...

可以看到,使用dataset.enumerate()内置方法,返回的的一个值是一个Tensor(张量)。

3、重复迭代dataset

如果需要重复多次对dataset进行迭代,可以使用dataset的内置方法repeat()。示例:

for i, batch in dataset.repeat(2).enumerate():
  if i == 255 or i == 256:
  print(i.numpy(), batch.shape)
...
255 (8, 8, 3)
256 (8, 8, 3)

4、使用take()获取指定数量大小的样本数

如果不想使用整个数据集,可以使用take()方法来获取指定数量的数据集:

for batch in dataset.take(3):
  print(batch.shape)
...
(8, 8, 3)
(8, 8, 3)
(8, 8, 3)

5、设置batch size

默认情况下,dataset是以batch size为1来迭代,可以使用batch()方法设置batch size的大小。

dataset = dataset.batch(16)
for batch in dataset.take(3):
  print(batch.shape)
...
(16, 8, 8, 3)
(16, 8, 8, 3)
(16, 8, 8, 3)

设置了batch size为16

6、打乱数据集

shuffle()方法可以用来打乱数据,其中shuffle()方法会接收一个buffer_size的参数,这个参数作为一个每一次打乱数据的缓存区,也即是每次去出buffer_size大小的数据进行打乱。如果想完全打乱整个数据集,buffer_size需要设置为整个数据集的大小。

示例:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(19))
for batch in dataset.batch(5):
  print(batch)
...
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
...
for batch in dataset.shuffle(5).batch(5):
  print(batch)
...
tf.Tensor([2 5 0 4 1], shape=(5,), dtype=int64)
tf.Tensor([ 6  9  3 12 10], shape=(5,), dtype=int64)
tf.Tensor([13  8 15 17 11], shape=(5,), dtype=int64)
tf.Tensor([18 16 14  7], shape=(4,), dtype=int64)

可以看到shuffle()的buffer_size为5,batch size也是5,每次取出5个数据,并进行打乱。打乱后,每个批次的数据就不是原来按顺序的了。

需要注意的是,如果把shuffle()方法和batch()方法调转,会导致的结果是对批次打乱,而不是对数据集里的数据打乱。

for batch in dataset.batch(5).shuffle(5):
  print(batch)
...
tf.Tensor([5 6 7 8 9], shape=(5,), dtype=int64)
tf.Tensor([15 16 17 18], shape=(4,), dtype=int64)
tf.Tensor([0 1 2 3 4], shape=(5,), dtype=int64)
tf.Tensor([10 11 12 13 14], shape=(5,), dtype=int64)

7、转换数据

如果我们想对导入的数据做预处理,可以使用map方法。

def tranform(data):
  mean = tf.reduce_mean(data)
  return data - mean
for batch in dataset.shuffle(5).batch(5).map(tranform):
  print(batch)
...
tf.Tensor([ 2  3 -1  0 -2], shape=(5,), dtype=int64)
tf.Tensor([-2 -5  2  3  4], shape=(5,), dtype=int64)
tf.Tensor([-1  1  2 -5  3], shape=(5,), dtype=int64)
tf.Tensor([ 3 -3  7 -4], shape=(4,), dtype=int64)

8、预取指定大小的batch来做训练

通常,读取和处理dataset的数据会很耗时,即耗CPU时间,为了让GPU不出现太多空闲,可以使用prefetch()方法预取一定数据的batch来做训练。

dataset.shuffle(5).batch(5).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

其中,把buffer_size设置为tf.data.experimental.AUTOTUNE,意思是让TensorFlow自己找到一个合适的最优的buffer_size。

©2020 edoou.com   京ICP备16001874号-3