Using Huge, Heterogeneous Datasets in TensorFlow

When working with TensorFlow, a dataset can sometimes be too large to fit entirely in main memory. TensorFlow provides the tf.data.Dataset API to reduce the memory footprint and improve efficiency when working with big datasets. However, the examples in the documentation are built around common data types such as text and images, and it is unclear how to adapt these approaches to other kinds of huge custom datasets. In this post, I discuss a method that I developed for huge datasets containing heterogeneous data types (based on https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). This technique stores the huge dataset on disk, which reduces memory consumption, while still allowing one to use the dataset APIs for efficient dataset transformations and multi-GPU training.

(Examples were tested with Python 3.8 and TensorFlow 2.8.0.)

A huge, heterogeneous dataset

A heterogeneous dataset is defined as follows:

  1. Each sample in the dataset is a set of key-value pairs (a dict object in Python)
  2. The keys in each sample are the same across the entire dataset
  3. For each key, the corresponding value is an array (or a scalar, which is treated as an array of size 1)
  4. For a given key, the data type of the arrays is the same across samples
  5. For a given key, the arrays have the same number of dimensions across samples
  6. For a given key, the size of the last array dimension is the same across samples

Although this is a list of six conditions, I think most datasets can be represented in this way. An example of a sample from such a dataset is:

{
    'thumbnail': np.zeros((100, 100, 3), np.uint8),
    'waveform': np.zeros((100000, 2), np.int16),
    'spectrum': np.zeros((22050, 2), np.float32),
    'label': 0
}

This can be thought of as the kind of information that might be extracted from a music audio file.

The dataset is “heterogeneous” because it contains many fields of different data types.

In order to proceed, one should preprocess the dataset so that the data samples can be represented as a list of such dict objects. Note that it is not required to transform all samples at once; the dict objects can also come from a generator, as sketched below.
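
For instance, a minimal sketch of such a generator, reusing the dummy arrays from the example above in place of real feature extraction, might look like this (the file paths and the zero-filled arrays are placeholders):

import numpy as np

def sample_generator(audio_paths):
    # yields one sample dict per input file; the arrays here are dummies,
    # in practice they would be computed from the audio file at `path`
    for path in audio_paths:
        yield {
            'thumbnail': np.zeros((100, 100, 3), np.uint8),
            'waveform': np.zeros((100000, 2), np.int16),
            'spectrum': np.zeros((22050, 2), np.float32),
            'label': 0
        }

The dataset-building loop in the next section can then iterate over sample_generator(...) instead of a pre-built list, so that only one sample needs to be held in memory at a time.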

Converting the dataset to TensorFlow format

In the following example, we convert the data sample dict objects into a TFRecordDataset dataset.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '' # optionally hide the GPU; it is not needed for dataset generation
import tensorflow as tf
import numpy as np
import json

# wrapper for binary features (in TFRecordDataset)
def bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# wrapper for int64 features (in TFRecordDataset)
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# the input sample is a `dict` object
def transform_sample(sample):
    
    # this dictionary is used to store the data type and shape information of each key
    key_info = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            key_info[key] = dict(dtype=repr(val.dtype), shape=val.shape)
        elif isinstance(val, int):
            # special handler for int64 type
            key_info[key] = '@@int'
        else:
            raise RuntimeError('unknown data type')

    # this dictionary is the one containing the data sample
    tf_sample = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            tensor_val = tf.convert_to_tensor(val)
            tensor_bin = tf.io.serialize_tensor(tensor_val)
            tf_sample[key] = bytes_feature(tensor_bin)
        elif isinstance(val, int):
            # special handler for int64 type
            tf_sample[key] = int64_feature(val)
        else:
            raise RuntimeError('unknown data type')
    
    return tf.train.Example(features=tf.train.Features(feature=tf_sample)), key_info


samples = [] # a list of samples, or a generator of samples

key_info = None # used to keep track of the data type and shape information of each key
                # (only the last one is needed since the keys are the same across samples)


num_samples = 0 # keep track of the number of samples
# the samples are serialized and written to this file
with tf.io.TFRecordWriter('dataset.tfrecords') as writer:
    # the dataset building loop
    for sample in samples:
        try:
            tf_sample, key_info = transform_sample(sample)
            writer.write(tf_sample.SerializeToString())
            num_samples += 1
        except Exception as e:
            print(f'an error occurred while processing a sample: {repr(e)}')

# build and save dataset metadata
dataset_metadata = {
    'num_samples' : num_samples,
    'key_info' : key_info 
}
with open('dataset_metadata.json', 'w') as outfile:
    json.dump(dataset_metadata, outfile, indent=2)
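
To make the metadata format concrete: for the eight dummy samples used in the complete example files at the end of this post, the content written to dataset_metadata.json corresponds roughly to the following Python dict once loaded back with json.load. Note that the dtype entries are the repr of the NumPy dtypes, which is why they are later parsed with a regular expression.

dataset_metadata = {
    'num_samples': 8,
    'key_info': {
        'thumbnail': {'dtype': "dtype('uint8')", 'shape': [100, 100, 3]},
        'waveform': {'dtype': "dtype('int16')", 'shape': [100000, 2]},
        'spectrum': {'dtype': "dtype('float32')", 'shape': [22050, 2]},
        'label': '@@int'
    }
}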


Using the dataset

import os
import tensorflow as tf
import numpy as np
import json
import re

# load back dataset metadata
with open('dataset_metadata.json', 'r') as infile:
    dataset_metadata = json.load(infile)

# compute number of training steps
batch_size = 4
num_train_steps = int(np.ceil(dataset_metadata['num_samples']/batch_size))

# generate a dataset description for deserialization
dataset_description = dict()
for key, val in dataset_metadata['key_info'].items():
    if val == '@@int':
        dataset_description[key] = tf.io.FixedLenFeature([], tf.int64)
    else:
        dataset_description[key] = tf.io.FixedLenFeature([], tf.string)

# define the first transformation to deserialize the dataset
def parse_dataset(raw_example):
    parsed = tf.io.parse_single_example(raw_example, dataset_description)
    for key, val in dataset_metadata['key_info'].items():
        # this step is only needed for non-integer type
        if val != '@@int':
            s_res = re.search("'(.*?)'", val['dtype']) # search for the data type
            type_str = s_res.group(1)
            parsed[key] = tf.io.parse_tensor(parsed[key], getattr(tf, type_str))
    return parsed

# define the second transformation to split a sample into the sample and its label
def export_sample_and_label(x):
    return x, x['label']

# define the third transformation to specify the shapes of the keys
# this is mandatory; otherwise errors may occur in training
# see https://github.com/tensorflow/tensorflow/issues/32912#issuecomment-550363802
def fix_shape(x, y):
    for key, val in dataset_metadata['key_info'].items():
        if val != '@@int':
            x[key].set_shape(val['shape'])
    return x, y

# load the dataset back
# if there are multiple files (e.g., generated by distributed dataset generation), 
# specify them in the `filenames` list 
dataset = (
    tf.data.TFRecordDataset(filenames=['dataset.tfrecords'])
    .map(parse_dataset)
    .map(export_sample_and_label)
    .map(fix_shape)
)

train_ds = dataset.batch(batch_size).prefetch(batch_size).repeat()

'''
what `train_ds` looks like:
<RepeatDataset element_spec=({'label': TensorSpec(shape=(None,), dtype=tf.int64, name=None), 
    'spectrum': TensorSpec(shape=(None, None, 2), dtype=tf.float32, name=None), 
    'thumbnail': TensorSpec(shape=(None, None, None, 3), dtype=tf.uint8, name=None), 
    'waveform': TensorSpec(shape=(None, None, 2), dtype=tf.int16, name=None)}, 
    TensorSpec(shape=(None,), dtype=tf.int64, name=None))>
'''

# define a custom model
class MyModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.proj_a = tf.keras.layers.Dense(10)
        self.proj_b = tf.keras.layers.Dense(10)
        self.proj_c = tf.keras.layers.Dense(10)
        self.last_layer = tf.keras.layers.Dense(1, activation='sigmoid')
        

    def call(self, x, training=False):
        # do some dummy operations to each field
        a = tf.cast(x['thumbnail'], tf.float32)
        b = tf.cast(x['waveform'], tf.float32)
        c = x['spectrum']
        
        value = [
            self.proj_a(a[:, :, 0, 0]),
            self.proj_b(b[:, :, 0]),
            self.proj_c(c[:, :, 0])
        ]
        
        value = tf.concat(value, axis=1)
        
        return self.last_layer(value)


# initialize the dummy model and start training
model = MyModel()
model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)
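
The introduction mentions multi-GPU training as one benefit of this setup. A minimal sketch, assuming a machine with several GPUs visible to TensorFlow, is to build and compile the model inside a tf.distribute.MirroredStrategy scope; Keras then splits each batch of train_ds across the available GPUs:

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # model variables and optimizer state must be created inside the scope
    model = MyModel()
    model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)

With this strategy, batch_size acts as the global batch size that is split across the replicas, so it is usually scaled up with the number of GPUs.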


Complete dummy example Python files

Dataset generation

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '' # optionally hide the GPU; it is not needed for dataset generation
import tensorflow as tf
import numpy as np
import json

# wrapper for binary features (in TFRecordDataset)
def bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy() # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# wrapper for int64 features (in TFRecordDataset)
def int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# the input sample is a `dict` object
def transform_sample(sample):
    
    # this dictionary is used to store the data type and shape information of each key
    key_info = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            key_info[key] = dict(dtype=repr(val.dtype), shape=val.shape)
        elif isinstance(val, int):
            # special handler for int64 type
            key_info[key] = '@@int'
        else:
            raise RuntimeError('unknown data type')

    # this dictionary is the one containing the data sample
    tf_sample = dict()
    for key, val in sample.items():
        if isinstance(val, np.ndarray):
            # generic handler for np.ndarray
            tensor_val = tf.convert_to_tensor(val)
            tensor_bin = tf.io.serialize_tensor(tensor_val)
            tf_sample[key] = bytes_feature(tensor_bin)
        elif isinstance(val, int):
            # special handler for int64 type
            tf_sample[key] = int64_feature(val)
        else:
            raise RuntimeError('unknown data type')
    
    return tf.train.Example(features=tf.train.Features(feature=tf_sample)), key_info


# create some dummy samples
samples = []
for _ in range(8):
    samples.append({
        'thumbnail': np.zeros((100, 100, 3), np.uint8),
        'waveform': np.zeros((100000, 2), np.int16),
        'spectrum': np.zeros((22050, 2), np.float32),
        'label': 0
    })

key_info = None # used to keep track of the data type and shape information of each key
                # (only the last one is needed since the keys are the same across samples)


num_samples = 0 # keep track of the number of samples
# the samples are serialized and written to this file
with tf.io.TFRecordWriter('dataset.tfrecords') as writer:
    # the dataset building loop
    for sample in samples:
        try:
            tf_sample, key_info = transform_sample(sample)
            writer.write(tf_sample.SerializeToString())
            num_samples += 1
        except Exception as e:
            print(f'an error occurred while processing a sample: {repr(e)}')

# build and save dataset metadata
dataset_metadata = {
    'num_samples' : num_samples,
    'key_info' : key_info 
}
with open('dataset_metadata.json', 'w') as outfile:
    json.dump(dataset_metadata, outfile, indent=2)

Dataset testing

import os
import tensorflow as tf
import numpy as np
import json
import re

# load back dataset metadata
with open('dataset_metadata.json', 'r') as infile:
    dataset_metadata = json.load(infile)

# compute number of training steps
batch_size = 4
num_train_steps = int(np.ceil(dataset_metadata['num_samples']/batch_size))

# generate a dataset description for deserialization
dataset_description = dict()
for key, val in dataset_metadata['key_info'].items():
    if val == '@@int':
        dataset_description[key] = tf.io.FixedLenFeature([], tf.int64)
    else:
        dataset_description[key] = tf.io.FixedLenFeature([], tf.string)

# define the first transformation to deserialize the dataset
def parse_dataset(raw_example):
    parsed = tf.io.parse_single_example(raw_example, dataset_description)
    for key, val in dataset_metadata['key_info'].items():
        # this step is only needed for non-integer type
        if val != '@@int':
            s_res = re.search("'(.*?)'", val['dtype']) # search for the data type
            type_str = s_res.group(1)
            parsed[key] = tf.io.parse_tensor(parsed[key], getattr(tf, type_str))
    return parsed

# define the second transformation to split a sample into the sample and its label
def export_sample_and_label(x):
    return x, x['label']

# define the third transformation to specify the shapes of the keys
# this is mandatory; otherwise errors may occur in training
# see https://github.com/tensorflow/tensorflow/issues/32912#issuecomment-550363802
def fix_shape(x, y):
    for key, val in dataset_metadata['key_info'].items():
        if val != '@@int':
            x[key].set_shape(val['shape'])
    return x, y

# load the dataset back
# if there are multiple files (e.g., generated by distributed dataset generation), 
# specify them in the `filenames` list 
dataset = (
    tf.data.TFRecordDataset(filenames=['dataset.tfrecords'])
    .map(parse_dataset)
    .map(export_sample_and_label)
    .map(fix_shape)
)

train_ds = dataset.batch(batch_size).prefetch(batch_size).repeat()

# define a custom model
class MyModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        self.proj_a = tf.keras.layers.Dense(10)
        self.proj_b = tf.keras.layers.Dense(10)
        self.proj_c = tf.keras.layers.Dense(10)
        self.last_layer = tf.keras.layers.Dense(1, activation='sigmoid')
        

    def call(self, x, training=False):
        # do some dummy operations to each field
        a = tf.cast(x['thumbnail'], tf.float32)
        b = tf.cast(x['waveform'], tf.float32)
        c = x['spectrum']
        
        value = [
            self.proj_a(a[:, :, 0, 0]),
            self.proj_b(b[:, :, 0]),
            self.proj_c(c[:, :, 0])
        ]
        
        value = tf.concat(value, axis=1)
        
        return self.last_layer(value)


# initialize the dummy model and start training
model = MyModel()
model.compile(loss='binary_crossentropy', optimizer='Adam')
model.fit(train_ds, epochs=2, steps_per_epoch=num_train_steps)